diff --git a/CMakeLists.txt b/CMakeLists.txt index 50496efa1..c8e354f61 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -46,7 +46,7 @@ endif() option(BUILD_EXAMPLES "Build examples" TRUE) option(BUILD_FT "Build functional tests" TRUE) -option(BUILD_UT "Build unit tests" FALSE) +option(BUILD_REG_TESTS "Build regression tests" TRUE) option(BUILD_CONFIG "Build cmake configs" TRUE) option(ENABLE_MPI "Enable MPI for library" TRUE) option(ENABLE_MPI_TESTS "Enable MPI for tests" TRUE) @@ -70,17 +70,16 @@ message(STATUS "C compiler : ${CMAKE_C_COMPILER}") message(STATUS "CXX compiler : ${CMAKE_CXX_COMPILER}") message(STATUS "Build examples: ${BUILD_EXAMPLES}") message(STATUS "Build functional tests: ${BUILD_FT}") -message(STATUS "Build unit tests: ${BUILD_UT}") message(STATUS "Build cmake configs: ${BUILD_CONFIG}") -message(STATUS "Enable MPI for library: ${ENABLE_MPI}") -message(STATUS "Enable MPI for tests: ${ENABLE_MPI_TESTS}") -message(STATUS "Enable support for interop event functionality: ${ENABLE_SYCL_INTEROP_EVENT}") -message(STATUS "Enable support for OFI HMEM: ${ENABLE_OFI_HMEM}") +message(STATUS "Enable MPI support: ${ENABLE_MPI}") +message(STATUS "Enable MPI tests support: ${ENABLE_MPI_TESTS}") +message(STATUS "Enable SYCL interop event support: ${ENABLE_SYCL_INTEROP_EVENT}") +message(STATUS "Enable OFI HMEM support: ${ENABLE_OFI_HMEM}") add_definitions(-DCCL_C_COMPILER="${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}") add_definitions(-DCCL_CXX_COMPILER="${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}") -SET(MULTI_GPU_SUPPORT OFF CACHE BOOL "Enable Multi GPU extension support") +SET(CCL_ENABLE_ZE OFF CACHE BOOL "Enable Level Zero support") set(CCL_COMMON_INSTALL_PREFIX "intel64") set(CMAKE_INSTALL_LIBDIR "lib") @@ -94,11 +93,8 @@ set(CCL_INSTALL_LICENSE "${CMAKE_INSTALL_PREFIX}/licensing") set(CCL_INSTALL_MODULE "${CMAKE_INSTALL_PREFIX}/modulefiles") set(CCL_INSTALL_EXAMPLES "${CMAKE_INSTALL_PREFIX}/examples") set(CCL_INSTALL_TESTS 
"${CMAKE_INSTALL_PREFIX}/tests") -set(CCL_INSTALL_UNIT_TESTS "${CMAKE_INSTALL_PREFIX}/tests/unit") set(CCL_INSTALL_KERNELS "${CMAKE_INSTALL_PREFIX}/lib/kernels") -set(CCL_UNIT_TESTS_BUILD "${CMAKE_BINARY_DIR}/tests/unit") - # setup dependency directories set(DEPS_DIR "${PROJECT_SOURCE_DIR}/deps") @@ -133,25 +129,21 @@ if (${CMAKE_VERSION} VERSION_LESS 3.1) set(C_COMPILER_FLAGS "-std=gnu99") endif() -# TODO: add -Wextra to c/cxx flags - # common release/debug compilation settings -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${C_COMPILER_FLAGS} -Wall -Werror -D_GNU_SOURCE -fvisibility=internal") +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${C_COMPILER_FLAGS} -Wall -Wextra -Wno-unused-parameter -Wno-implicit-fallthrough -Werror -D_GNU_SOURCE -fvisibility=internal") set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} ${C_COMPILER_FLAGS} -O0 -g -DENABLE_DEBUG") set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} ${C_COMPILER_FLAGS} -O3") set(CMAKE_C_FLAGS_RELWITHDEBINFO "${CMAKE_C_FLAGS_RELWITHDEBINFO} ${C_COMPILER_FLAGS} -O2 -g") set(CMAKE_C_STANDARD 99) set(CMAKE_C_STANDARD_REQUIRED ON) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CXX_COMPILER_FLAGS} -Wall -Werror -D_GNU_SOURCE -fvisibility=internal") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CXX_COMPILER_FLAGS} -Wall -Wextra -Wno-unused-parameter -Wno-implicit-fallthrough -Werror -D_GNU_SOURCE -fvisibility=internal") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} ${CXX_COMPILER_FLAGS} -O0 -g -DENABLE_DEBUG") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} ${CXX_COMPILER_FLAGS} -O3") set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} ${CXX_COMPILER_FLAGS} -O2 -g") set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD_REQUIRED ON) -set(TRY_ENABLE_SYCL_L0 ON) - set(COMMON_CMAKE_DIR ${PROJECT_SOURCE_DIR}/cmake) if (COMPUTE_BACKEND) message(STATUS "COMPUTE_BACKEND: ${COMPUTE_BACKEND}") @@ -192,6 +184,7 @@ if (WITH_ASAN AND ${CMAKE_BUILD_TYPE_CASE_INSENSITIVE} STREQUAL "debug") endif() set_lp_env() +set_avx_env() 
set(CCL_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/src) @@ -227,7 +220,7 @@ file(GLOB spv_kernels "${PROJECT_SOURCE_DIR}/src/kernels/kernels.spv") endif() set(CCL_MAJOR_VERSION "2021") -set(CCL_MINOR_VERSION "4") +set(CCL_MINOR_VERSION "5") set(CCL_UPDATE_VERSION "0") set(CCL_PRODUCT_STATUS "Gold") string(TIMESTAMP CCL_PRODUCT_BUILD_DATE "%Y-%m-%dT %H:%M:%SZ") @@ -257,7 +250,9 @@ if (ENABLE_MPI_TESTS) add_subdirectory(examples/benchmark) add_subdirectory(examples/common) add_subdirectory(examples/cpu) - add_subdirectory(examples/external_launcher) + if (BUILD_CONFIG) + add_subdirectory(examples/external_launcher) + endif() if (CCL_ENABLE_SYCL) add_subdirectory(examples/sycl) endif() @@ -265,7 +260,4 @@ if (ENABLE_MPI_TESTS) if (BUILD_FT) add_subdirectory(tests/functional) endif() - if (BUILD_UT AND EXISTS "${PROJECT_SOURCE_DIR}/tests/unit") - add_subdirectory(tests/unit) - endif() endif() diff --git a/cmake/helpers.cmake b/cmake/helpers.cmake index 245de34e1..8bba8e98e 100644 --- a/cmake/helpers.cmake +++ b/cmake/helpers.cmake @@ -88,6 +88,45 @@ function(set_lp_env) endfunction(set_lp_env) +# Inspect the C compiler ID/version and add directory-scoped definitions: +# CCL_AVX_COMPILER: Intel (any version), Clang >= 9.0.0, GNU >= 4.9.0 +# CCL_AVX_TARGET_ATTRIBUTES: only for Clang/GNU among the compilers above +# Exports AVX_ENV_DEFINED=1 to the caller; the CCL_AVX_* CMake variables stay +# local to this function and only feed the status messages. +function(set_avx_env) + + set(GCC_AVX_MIN_SUPPORTED "4.9.0") + set(CLANG_AVX_MIN_SUPPORTED "9.0.0") + + # Bare variable names are dereferenced by if() itself, so an empty + # compiler ID/version cannot leave a malformed condition behind. + if (CMAKE_C_COMPILER_ID STREQUAL "Intel" + OR (CMAKE_C_COMPILER_ID STREQUAL "Clang" + AND NOT CMAKE_C_COMPILER_VERSION VERSION_LESS CLANG_AVX_MIN_SUPPORTED) + OR (CMAKE_C_COMPILER_ID STREQUAL "GNU" + AND NOT CMAKE_C_COMPILER_VERSION VERSION_LESS GCC_AVX_MIN_SUPPORTED) + ) + add_definitions(-DCCL_AVX_COMPILER) + set(CCL_AVX_COMPILER ON) + else() + set(CCL_AVX_COMPILER OFF) + endif() + message(STATUS "AVX compiler: ${CCL_AVX_COMPILER}") + + if (CCL_AVX_COMPILER) + if (CMAKE_C_COMPILER_ID STREQUAL "Clang" OR CMAKE_C_COMPILER_ID STREQUAL "GNU") + add_definitions(-DCCL_AVX_TARGET_ATTRIBUTES) + set(CCL_AVX_TARGET_ATTRIBUTES ON) + else() + set(CCL_AVX_TARGET_ATTRIBUTES OFF) + endif() + message(STATUS "AVX target attributes: ${CCL_AVX_TARGET_ATTRIBUTES}") +
endif() + + set(AVX_ENV_DEFINED 1 PARENT_SCOPE) + +endfunction(set_avx_env) + function(check_compiler_version) set(GCC_MIN_SUPPORTED "4.8") @@ -293,11 +332,11 @@ function(set_compute_backend COMMON_CMAKE_DIR) endif() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${COMPUTE_BACKEND_FLAGS}") if (${COMPUTE_BACKEND_TARGET_NAME} STREQUAL "Intel::SYCL_level_zero" OR ${COMPUTE_BACKEND_TARGET_NAME} STREQUAL "ze_loader") - set(MULTI_GPU_SUPPORT ON PARENT_SCOPE) - set(MULTI_GPU_SUPPORT ON) + set(CCL_ENABLE_ZE ON PARENT_SCOPE) + set(CCL_ENABLE_ZE ON) endif() - if (MULTI_GPU_SUPPORT) - message(STATUS "Enable GPU support using level-zero") + if (CCL_ENABLE_ZE) + message(STATUS "Enable Level Zero support") endif() # need to pass these variables to overlying function diff --git a/cmake/templates/oneCCLConfig.cmake.in b/cmake/templates/oneCCLConfig.cmake.in index 86b7de9f8..0decd0c19 100644 --- a/cmake/templates/oneCCLConfig.cmake.in +++ b/cmake/templates/oneCCLConfig.cmake.in @@ -23,11 +23,11 @@ if (EXISTS "${CCL_CONFIGURATION}") set(_oneccl_subdir "${CCL_CONFIGURATION}") endif() -if (_oneccl_subdir EQUAL "cpu_icc") +if (_oneccl_subdir STREQUAL "cpu") include(CheckCXXCompilerFlag) check_cxx_compiler_flag("-fsycl" _fsycl_option) if (_fsycl_option) - message(STATUS "STATUS: -fsycl not supported for CCL_CONFIGURATION=cpu_icc") + message(STATUS "STATUS: -fsycl not supported for CCL_CONFIGURATION=cpu") endif() endif() diff --git a/deps/mpi/bin/hydra_bstrap_proxy b/deps/mpi/bin/hydra_bstrap_proxy index 6a2e27a5b..665e28ffd 100755 Binary files a/deps/mpi/bin/hydra_bstrap_proxy and b/deps/mpi/bin/hydra_bstrap_proxy differ diff --git a/deps/mpi/bin/hydra_nameserver b/deps/mpi/bin/hydra_nameserver index 3af2dc9bc..5d91b8cae 100755 Binary files a/deps/mpi/bin/hydra_nameserver and b/deps/mpi/bin/hydra_nameserver differ diff --git a/deps/mpi/bin/hydra_pmi_proxy b/deps/mpi/bin/hydra_pmi_proxy index 6e09d880f..644a5a5e7 100755 Binary files a/deps/mpi/bin/hydra_pmi_proxy and b/deps/mpi/bin/hydra_pmi_proxy
differ diff --git a/deps/mpi/bin/mpiexec b/deps/mpi/bin/mpiexec index 61a4ff30a..2fca15c37 100755 Binary files a/deps/mpi/bin/mpiexec and b/deps/mpi/bin/mpiexec differ diff --git a/deps/mpi/bin/mpiexec.hydra b/deps/mpi/bin/mpiexec.hydra index 61a4ff30a..2fca15c37 100755 Binary files a/deps/mpi/bin/mpiexec.hydra and b/deps/mpi/bin/mpiexec.hydra differ diff --git a/deps/mpi/bin/mpigcc b/deps/mpi/bin/mpigcc index c54304f11..338750e29 100755 --- a/deps/mpi/bin/mpigcc +++ b/deps/mpi/bin/mpigcc @@ -99,12 +99,8 @@ fi # Determined by a combination of environment variables and tests within # configure (e.g., determining whehter -lsocket is needee) CC="gcc" -MPICH_VERSION="3.3" -CFLAGS="" -CPPFLAGS="" -LDFLAGS=" -Wl,-z,now -Wl,-z,relro -Wl,-z,noexecstack -Xlinker --enable-new-dtags -ldl " -LIBS="-lm -lpthread -lfabric -lrt " -MPIVERSION="2021.4" +MPICH_VERSION="3.4a2" +MPIVERSION="2021.5" MPILIBNAME="mpi" @@ -594,10 +590,6 @@ fi final_cppflags=" " final_ldflags=" -Wl,-z,now -Wl,-z,relro -Wl,-z,noexecstack -Xlinker --enable-new-dtags -ldl " final_libs="-lpthread -lrt " -if test "no" = "no" -o "${interlib_deps}" = "no" ; then - final_ldflags="${final_ldflags} -Wl,-z,now -Wl,-z,relro -Wl,-z,noexecstack -Xlinker --enable-new-dtags -ldl -L/p/pdsd/scratch/jenkins/artefacts_impi_2019/hcoll/lib -L/p/pdsd/scratch/Uploads/IMPI/other/software/libfabric/linux/v1.9.0/lib" - final_libs="${final_libs} -lm -lpthread -lfabric -lrt " -fi # ----------------------------------------------------------------------- # @@ -622,7 +614,7 @@ if [ "$linking" = yes ] ; then $Show $CC ${final_cppflags} $PROFILE_INCPATHS ${final_cflags} ${final_ldflags} $allargs -I\"${includedir}\" rc=$? else - $Show $CC $CPPFLAGS $CFLAGS $allargs -I\"${includedir}\" -L\"${libdir}${MPILIBDIR}\" -L\"${libdir}\" $rpath_opt $mpilibs $I_MPI_OTHERLIBS $LDFLAGS + $Show $CC $CPPFLAGS $CFLAGS $allargs -I\"${includedir}\" -L\"${libdir}${MPILIBDIR}\" -L\"${libdir}\" $rpath_opt $mpilibs $I_MPI_OTHERLIBS ${final_ldflags} rc=$? 
if [ $rc -eq 0 -a "x$strip_debug_info" = "xyes" ] ; then diff --git a/deps/mpi/bin/mpigxx b/deps/mpi/bin/mpigxx index b9382fd8c..65ac0f5d3 100755 --- a/deps/mpi/bin/mpigxx +++ b/deps/mpi/bin/mpigxx @@ -97,11 +97,8 @@ fi # Default settings for compiler, flags, and libraries CXX="g++" -MPICH_VERSION="3.3" -CXXFLAGS="" -LDFLAGS=" -Wl,-z,now -Wl,-z,relro -Wl,-z,noexecstack -Xlinker --enable-new-dtags -ldl " -LIBS="-lm -lpthread -lfabric -lrt " -MPIVERSION="2021.4" +MPICH_VERSION="3.4a2" +MPIVERSION="2021.5" MPILIBNAME="mpi" MPICXXLIBNAME="mpicxx" @@ -606,10 +603,6 @@ fi final_cppflags=" " final_ldflags=" -Wl,-z,now -Wl,-z,relro -Wl,-z,noexecstack -Xlinker --enable-new-dtags -ldl " final_libs="-lpthread -lrt " -if test "no" = "no" -o "${interlib_deps}" = "no" ; then - final_ldflags="${final_ldflags} -Wl,-z,now -Wl,-z,relro -Wl,-z,noexecstack -Xlinker --enable-new-dtags -ldl -L/p/pdsd/scratch/jenkins/artefacts_impi_2019/hcoll/lib -L/p/pdsd/scratch/Uploads/IMPI/other/software/libfabric/linux/v1.9.0/lib" - final_libs="${final_libs} -lm -lpthread -lfabric -lrt " -fi # A temporary statement to invoke the compiler # Place the -L before any args incase there are any mpi libraries in there. @@ -625,7 +618,7 @@ if [ "$linking" = yes ] ; then $Show $CXX ${final_cppflags} $PROFILE_INCPATHS ${final_cxxflags} ${final_ldflags} $allargs -I\"${includedir}\" rc=$? else - $Show $CXX $CXXFLAGS $allargs -I\"${includedir}\" -L\"${libdir}${MPILIBDIR}\" -L\"${libdir}\" $rpath_opt $shllibpath $cxxlibs $mpilibs $I_MPI_OTHERLIBS $LDFLAGS + $Show $CXX $CXXFLAGS $allargs -I\"${includedir}\" -L\"${libdir}${MPILIBDIR}\" -L\"${libdir}\" $rpath_opt $shllibpath $cxxlibs $mpilibs $I_MPI_OTHERLIBS ${final_ldflags} rc=$? 
if [ $rc -eq 0 -a "x$strip_debug_info" = "xyes" ] ; then $Show objcopy --only-keep-debug ${executable} ${executable}.dbg diff --git a/deps/mpi/bin/mpiicc b/deps/mpi/bin/mpiicc index 25c4dea5b..581e5c29a 100755 --- a/deps/mpi/bin/mpiicc +++ b/deps/mpi/bin/mpiicc @@ -106,7 +106,7 @@ LDFLAGS="-ldl" MPILIBNAME="mpi" # MPIVERSION is the version of the MPICH2 library that mpicc is intended for -MPIVERSION="2021.4" +MPIVERSION="2021.5" # # Internal variables # Show is set to echo to cause the compilation command to be echoed instead diff --git a/deps/mpi/bin/mpiicpc b/deps/mpi/bin/mpiicpc index 1e221dbbd..e2377755f 100755 --- a/deps/mpi/bin/mpiicpc +++ b/deps/mpi/bin/mpiicpc @@ -107,7 +107,7 @@ MPILIBNAME="mpi" MPICXXLIBNAME="mpicxx" # MPIVERSION is the version of the Intel(R) MPI Library that mpiicpc is intended for -MPIVERSION="2021.4" +MPIVERSION="2021.5" # Internal variables # Show is set to echo to cause the compilation command to be echoed instead diff --git a/deps/mpi/etc/tuning_clx-ap_ofi.dat b/deps/mpi/etc/tuning_clx-ap_ofi.dat old mode 100755 new mode 100644 diff --git a/deps/mpi/etc/tuning_clx-ap_shm-ofi.dat b/deps/mpi/etc/tuning_clx-ap_shm-ofi.dat old mode 100755 new mode 100644 diff --git a/deps/mpi/etc/tuning_clx-ap_shm.dat b/deps/mpi/etc/tuning_clx-ap_shm.dat old mode 100755 new mode 100644 diff --git a/deps/mpi/etc/tuning_generic_ofi.dat b/deps/mpi/etc/tuning_generic_ofi.dat old mode 100755 new mode 100644 diff --git a/deps/mpi/etc/tuning_generic_shm-ofi.dat b/deps/mpi/etc/tuning_generic_shm-ofi.dat old mode 100755 new mode 100644 diff --git a/deps/mpi/etc/tuning_generic_shm.dat b/deps/mpi/etc/tuning_generic_shm.dat old mode 100755 new mode 100644 diff --git a/deps/mpi/etc/tuning_knl_ofi.dat b/deps/mpi/etc/tuning_knl_ofi.dat old mode 100755 new mode 100644 diff --git a/deps/mpi/etc/tuning_knl_shm-ofi.dat b/deps/mpi/etc/tuning_knl_shm-ofi.dat old mode 100755 new mode 100644 diff --git a/deps/mpi/etc/tuning_knl_shm.dat b/deps/mpi/etc/tuning_knl_shm.dat old 
mode 100755 new mode 100644 diff --git a/deps/mpi/etc/tuning_skx_ofi.dat b/deps/mpi/etc/tuning_skx_ofi.dat old mode 100755 new mode 100644 diff --git a/deps/mpi/etc/tuning_skx_shm-ofi.dat b/deps/mpi/etc/tuning_skx_shm-ofi.dat old mode 100755 new mode 100644 diff --git a/deps/mpi/etc/tuning_skx_shm.dat b/deps/mpi/etc/tuning_skx_shm.dat old mode 100755 new mode 100644 diff --git a/deps/mpi/include/mpi.h b/deps/mpi/include/mpi.h old mode 100755 new mode 100644 index 3dc48685b..adc1f2297 --- a/deps/mpi/include/mpi.h +++ b/deps/mpi/include/mpi.h @@ -580,8 +580,8 @@ typedef int (MPI_Delete_function) ( MPI_Comm, int, void *, void * ); * digits for REV, 1 digit for EXT and 2 digits for EXT_NUMBER. So, * 2019.0.0b0 will have the numeric version 20190000100. */ -#define I_MPI_VERSION "2021.4.0" -#define I_MPI_NUMVERSION 20210400300 +#define I_MPI_VERSION "2021.5.0" +#define I_MPI_NUMVERSION 20210500300 /* for the datatype decoders */ enum MPIR_Combiner_enum { diff --git a/deps/mpi/include/mpicxx.h b/deps/mpi/include/mpicxx.h old mode 100755 new mode 100644 diff --git a/deps/mpi/include/mpio.h b/deps/mpi/include/mpio.h old mode 100755 new mode 100644 diff --git a/deps/mpi/lib/libmpi.so b/deps/mpi/lib/libmpi.so index 84631e5a7..5b05a5027 100755 Binary files a/deps/mpi/lib/libmpi.so and b/deps/mpi/lib/libmpi.so differ diff --git a/deps/mpi/lib/libmpi.so.12 b/deps/mpi/lib/libmpi.so.12 index 84631e5a7..5b05a5027 100755 Binary files a/deps/mpi/lib/libmpi.so.12 and b/deps/mpi/lib/libmpi.so.12 differ diff --git a/deps/mpi/lib/libmpi.so.12.0 b/deps/mpi/lib/libmpi.so.12.0 index 84631e5a7..5b05a5027 100755 Binary files a/deps/mpi/lib/libmpi.so.12.0 and b/deps/mpi/lib/libmpi.so.12.0 differ diff --git a/deps/mpi/lib/libmpi.so.12.0.0 b/deps/mpi/lib/libmpi.so.12.0.0 index 84631e5a7..5b05a5027 100755 Binary files a/deps/mpi/lib/libmpi.so.12.0.0 and b/deps/mpi/lib/libmpi.so.12.0.0 differ diff --git a/deps/mpi/lib/libmpifort.so b/deps/mpi/lib/libmpifort.so index 00d80af4b..399678958 100755 
Binary files a/deps/mpi/lib/libmpifort.so and b/deps/mpi/lib/libmpifort.so differ diff --git a/deps/mpi/lib/libmpifort.so.12 b/deps/mpi/lib/libmpifort.so.12 index 00d80af4b..399678958 100755 Binary files a/deps/mpi/lib/libmpifort.so.12 and b/deps/mpi/lib/libmpifort.so.12 differ diff --git a/deps/mpi/lib/libmpifort.so.12.0 b/deps/mpi/lib/libmpifort.so.12.0 index 00d80af4b..399678958 100755 Binary files a/deps/mpi/lib/libmpifort.so.12.0 and b/deps/mpi/lib/libmpifort.so.12.0 differ diff --git a/deps/mpi/lib/libmpifort.so.12.0.0 b/deps/mpi/lib/libmpifort.so.12.0.0 index 00d80af4b..399678958 100755 Binary files a/deps/mpi/lib/libmpifort.so.12.0.0 and b/deps/mpi/lib/libmpifort.so.12.0.0 differ diff --git a/deps/mpi/licensing/license.txt b/deps/mpi/licensing/license.txt index ffffdc860..f987e502b 100644 --- a/deps/mpi/licensing/license.txt +++ b/deps/mpi/licensing/license.txt @@ -1,77 +1,73 @@ -Intel Simplified Software License (Version February 2020) +Intel Simplified Software License (Version August 2021) -Use and Redistribution. You may use and redistribute the software (the +Use and Redistribution. You may use and redistribute the software (the "Software"), without modification, provided the following conditions are met: -* Redistributions must reproduce the above copyright notice and the following - terms of use in the Software and in the documentation and/or other materials +* Redistributions must reproduce the above copyright notice and the following + terms of use in the Software and in the documentation and/or other materials provided with the distribution. * Neither the name of Intel nor the names of its suppliers may be used to - endorse or promote products derived from this Software without specific prior - written permission. -* No reverse engineering, decompilation, or disassembly of this Software is + endorse or promote products derived from this Software without specific + prior written permission. 
+* No reverse engineering, decompilation, or disassembly of this Software is permitted. -Limited patent license. Intel grants you a world-wide, royalty-free, -non-exclusive license under patents it now or hereafter owns or controls to -make, have made, use, import, offer to sell and sell ("Utilize") this Software, -but solely to the extent that any such patent is necessary to Utilize the -Software alone. The patent license shall not apply to any combinations which -include this software. No hardware per se is licensed hereunder. +No other licenses. Except as provided in the preceding section, Intel grants no +licenses or other rights by implication, estoppel or otherwise to, patent, +copyright, trademark, trade name, service mark or other intellectual property +licenses or rights of Intel. -Third party programs. The Software may contain Third Party Programs. "Third -Party Programs" are third party software, open source software or other Intel -software listed in the "third-party-programs.txt" or other similarly named text -file that is included with the Software. Third Party Programs, even if included -with the distribution of the Software, may be governed by separate license -terms, including without limitation, third party license terms, open source -software notices and terms, and/or other Intel software license terms. These -separate license terms may govern your use of the Third Party Programs. +Third party software. The Software may contain Third Party Software. "Third +Party Software" is open source software, third party software, or other Intel +software that may be identified in the Software itself or in the files (if any) +listed in the "third-party-software.txt" or similarly named text file included +with the Software. 
Third Party Software, even if included with the distribution +of the Software, may be governed by separate license terms, including without +limitation, open source software license terms, third party software license +terms, and other Intel software license terms. Those separate license terms +solely govern your use of the Third Party Software, and nothing in this license +limits any rights under, or grants rights that supersede, the terms of the +applicable license terms. -DISCLAIMER. THIS SOFTWARE IS PROVIDED "AS IS" AND ANY EXPRESS OR IMPLIED -WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT ARE -DISCLAIMED. THIS SOFTWARE IS NOT INTENDED FOR USE IN SYSTEMS OR APPLICATIONS -WHERE FAILURE OF THE SOFTWARE MAY CAUSE PERSONAL INJURY OR DEATH AND YOU AGREE -THAT YOU ARE FULLY RESPONSIBLE FOR ANY CLAIMS, COSTS, DAMAGES, EXPENSES, AND -ATTORNEYS' FEES ARISING OUT OF ANY SUCH USE, EVEN IF ANY CLAIM ALLEGES THAT -INTEL WAS NEGLIGENT REGARDING THE DESIGN OR MANUFACTURE OF THE MATERIALS. +DISCLAIMER. THIS SOFTWARE IS PROVIDED "AS IS" AND ANY EXPRESS OR IMPLIED +WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT ARE +DISCLAIMED. THIS SOFTWARE IS NOT INTENDED FOR USE IN SYSTEMS OR APPLICATIONS +WHERE FAILURE OF THE SOFTWARE MAY CAUSE PERSONAL INJURY OR DEATH AND YOU AGREE +THAT YOU ARE FULLY RESPONSIBLE FOR ANY CLAIMS, COSTS, DAMAGES, EXPENSES, AND +ATTORNEYS' FEES ARISING OUT OF ANY SUCH USE, EVEN IF ANY CLAIM ALLEGES THAT +INTEL WAS NEGLIGENT REGARDING THE DESIGN OR MANUFACTURE OF THE SOFTWARE. -LIMITATION OF LIABILITY. 
IN NO EVENT WILL INTEL BE LIABLE FOR ANY DIRECT, -INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE -OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF -ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. YOU AGREE TO INDEMNIFY AND HOLD -INTEL HARMLESS AGAINST ANY CLAIMS AND EXPENSES RESULTING FROM YOUR USE OR -UNAUTHORIZED USE OF THE SOFTWARE. +LIMITATION OF LIABILITY. IN NO EVENT WILL INTEL BE LIABLE FOR ANY DIRECT, +INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -No support. Intel may make changes to the Software, at any time without notice, -and is not obligated to support, update or provide training for the Software. +No support. Intel may make changes to the Software, at any time without notice, +and is not obligated to support, update or provide training for the Software. -Termination. Intel may terminate your right to use the Software in the event of -your breach of this Agreement and you fail to cure the breach within a -reasonable period of time. +Termination. Your right to use the Software is terminated in the event of your +breach of this license. -Feedback. 
Should you provide Intel with comments, modifications, corrections, -enhancements or other input ("Feedback") related to the Software Intel will be -free to use, disclose, reproduce, license or otherwise distribute or exploit the -Feedback in its sole discretion without any obligations or restrictions of any -kind, including without limitation, intellectual property rights or licensing +Feedback. Should you provide Intel with comments, modifications, corrections, +enhancements or other input ("Feedback") related to the Software, Intel will be +free to use, disclose, reproduce, license or otherwise distribute or exploit the +Feedback in its sole discretion without any obligations or restrictions of any +kind, including without limitation, intellectual property rights or licensing obligations. -Compliance with laws. You agree to comply with all relevant laws and regulations -governing your use, transfer, import or export (or prohibition thereof) of the +Compliance with laws. You agree to comply with all relevant laws and regulations +governing your use, transfer, import or export (or prohibition thereof) of the Software. -Governing law. All disputes will be governed by the laws of the United States of -America and the State of Delaware without reference to conflict of law -principles and subject to the exclusive jurisdiction of the state or federal -courts sitting in the State of Delaware, and each party agrees that it submits -to the personal jurisdiction and venue of those courts and waives any -objections. The United Nations Convention on Contracts for the International -Sale of Goods (1980) is specifically excluded and will not apply to the +Governing law. 
All disputes will be governed by the laws of the United States of +America and the State of Delaware without reference to conflict of law +principles and subject to the exclusive jurisdiction of the state or federal +courts sitting in the State of Delaware, and each party agrees that it submits +to the personal jurisdiction and venue of those courts and waives any +objections. The United Nations Convention on Contracts for the International +Sale of Goods (1980) is specifically excluded and will not apply to the Software. - -*Other names and brands may be claimed as the property of others. diff --git a/deps/mpi/licensing/third-party-programs.txt b/deps/mpi/licensing/third-party-programs.txt index f85123769..12d94d3db 100644 --- a/deps/mpi/licensing/third-party-programs.txt +++ b/deps/mpi/licensing/third-party-programs.txt @@ -1,4 +1,4 @@ -Intel(R) MPI Library 2021.4 Third Party Programs File +Intel(R) MPI Library 2021.5 Third Party Programs File This file is the "third-party-programs.txt" file specified in the associated Intel end user license agreement for the Intel software you are licensing. @@ -270,87 +270,82 @@ terms are listed below. ------------------------------------------------------------------------------- -5. Intel® Distribution for Python +5. Intel® Distribution for Python* - Intel Simplified Software License (Version February 2020) + Intel Simplified Software License (Version August 2021) - Use and Redistribution. You may use and redistribute the software (the + Use and Redistribution. You may use and redistribute the software (the "Software"), without modification, provided the following conditions are met: - * Redistributions must reproduce the above copyright notice and the following - terms of use in the Software and in the documentation and/or other materials - provided with the distribution. 
- * Neither the name of Intel nor the names of its suppliers may be used to - endorse or promote products derived from this Software without specific prior - written permission. - * No reverse engineering, decompilation, or disassembly of this Software is - permitted. - - Limited patent license. Intel grants you a world-wide, royalty-free, - non-exclusive license under patents it now or hereafter owns or controls to - make, have made, use, import, offer to sell and sell ("Utilize") this Software, - but solely to the extent that any such patent is necessary to Utilize the - Software alone. The patent license shall not apply to any combinations which - include this software. No hardware per se is licensed hereunder. - - Third party programs. The Software may contain Third Party Programs. "Third - Party Programs" are third party software, open source software or other Intel - software listed in the "third-party-programs.txt" or other similarly named text - file that is included with the Software. Third Party Programs, even if included - with the distribution of the Software, may be governed by separate license - terms, including without limitation, third party license terms, open source - software notices and terms, and/or other Intel software license terms. These - separate license terms may govern your use of the Third Party Programs. - - DISCLAIMER. THIS SOFTWARE IS PROVIDED "AS IS" AND ANY EXPRESS OR IMPLIED - WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF - MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT ARE - DISCLAIMED. THIS SOFTWARE IS NOT INTENDED FOR USE IN SYSTEMS OR APPLICATIONS - WHERE FAILURE OF THE SOFTWARE MAY CAUSE PERSONAL INJURY OR DEATH AND YOU AGREE - THAT YOU ARE FULLY RESPONSIBLE FOR ANY CLAIMS, COSTS, DAMAGES, EXPENSES, AND - ATTORNEYS' FEES ARISING OUT OF ANY SUCH USE, EVEN IF ANY CLAIM ALLEGES THAT - INTEL WAS NEGLIGENT REGARDING THE DESIGN OR MANUFACTURE OF THE MATERIALS. 
- - LIMITATION OF LIABILITY. IN NO EVENT WILL INTEL BE LIABLE FOR ANY DIRECT, - INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE - OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF - ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. YOU AGREE TO INDEMNIFY AND HOLD - INTEL HARMLESS AGAINST ANY CLAIMS AND EXPENSES RESULTING FROM YOUR USE OR - UNAUTHORIZED USE OF THE SOFTWARE. - - No support. Intel may make changes to the Software, at any time without notice, - and is not obligated to support, update or provide training for the Software. - - Termination. Intel may terminate your right to use the Software in the event of - your breach of this Agreement and you fail to cure the breach within a - reasonable period of time. - - Feedback. Should you provide Intel with comments, modifications, corrections, - enhancements or other input ("Feedback") related to the Software Intel will be - free to use, disclose, reproduce, license or otherwise distribute or exploit the - Feedback in its sole discretion without any obligations or restrictions of any - kind, including without limitation, intellectual property rights or licensing + * Redistributions must reproduce the above copyright notice and the following + terms of use in the Software and in the documentation and/or other materials + provided with the distribution. + * Neither the name of Intel nor the names of its suppliers may be used to + endorse or promote products derived from this Software without specific + prior written permission. + * No reverse engineering, decompilation, or disassembly of this Software is + permitted. + + No other licenses. 
Except as provided in the preceding section, Intel grants no + licenses or other rights by implication, estoppel or otherwise to, patent, + copyright, trademark, trade name, service mark or other intellectual property + licenses or rights of Intel. + + Third party software. The Software may contain Third Party Software. "Third + Party Software" is open source software, third party software, or other Intel + software that may be identified in the Software itself or in the files (if any) + listed in the "third-party-software.txt" or similarly named text file included + with the Software. Third Party Software, even if included with the distribution + of the Software, may be governed by separate license terms, including without + limitation, open source software license terms, third party software license + terms, and other Intel software license terms. Those separate license terms + solely govern your use of the Third Party Software, and nothing in this license + limits any rights under, or grants rights that supersede, the terms of the + applicable license terms. + + DISCLAIMER. THIS SOFTWARE IS PROVIDED "AS IS" AND ANY EXPRESS OR IMPLIED + WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT ARE + DISCLAIMED. THIS SOFTWARE IS NOT INTENDED FOR USE IN SYSTEMS OR APPLICATIONS + WHERE FAILURE OF THE SOFTWARE MAY CAUSE PERSONAL INJURY OR DEATH AND YOU AGREE + THAT YOU ARE FULLY RESPONSIBLE FOR ANY CLAIMS, COSTS, DAMAGES, EXPENSES, AND + ATTORNEYS' FEES ARISING OUT OF ANY SUCH USE, EVEN IF ANY CLAIM ALLEGES THAT + INTEL WAS NEGLIGENT REGARDING THE DESIGN OR MANUFACTURE OF THE SOFTWARE. + + LIMITATION OF LIABILITY. 
IN NO EVENT WILL INTEL BE LIABLE FOR ANY DIRECT, + INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE + OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + No support. Intel may make changes to the Software, at any time without notice, + and is not obligated to support, update or provide training for the Software. + + Termination. Your right to use the Software is terminated in the event of your + breach of this license. + + Feedback. Should you provide Intel with comments, modifications, corrections, + enhancements or other input ("Feedback") related to the Software, Intel will be + free to use, disclose, reproduce, license or otherwise distribute or exploit the + Feedback in its sole discretion without any obligations or restrictions of any + kind, including without limitation, intellectual property rights or licensing obligations. - Compliance with laws. You agree to comply with all relevant laws and regulations - governing your use, transfer, import or export (or prohibition thereof) of the + Compliance with laws. You agree to comply with all relevant laws and regulations + governing your use, transfer, import or export (or prohibition thereof) of the Software. - Governing law. All disputes will be governed by the laws of the United States of - America and the State of Delaware without reference to conflict of law - principles and subject to the exclusive jurisdiction of the state or federal - courts sitting in the State of Delaware, and each party agrees that it submits - to the personal jurisdiction and venue of those courts and waives any - objections. 
The United Nations Convention on Contracts for the International - Sale of Goods (1980) is specifically excluded and will not apply to the + Governing law. All disputes will be governed by the laws of the United States of + America and the State of Delaware without reference to conflict of law + principles and subject to the exclusive jurisdiction of the state or federal + courts sitting in the State of Delaware, and each party agrees that it submits + to the personal jurisdiction and venue of those courts and waives any + objections. The United Nations Convention on Contracts for the International + Sale of Goods (1980) is specifically excluded and will not apply to the Software. - - *Other names and brands may be claimed as the property of others. - - + ------------------------------------------------------------------------------- 6. uthash @@ -481,42 +476,110 @@ terms are listed below. ------------------------------------------------------------------------------- -10. PMIx - Copyright (c) 2019, PMIx - All rights reserved. +10. OpenPMIx + Most files in this release are marked with the copyrights of the +organizations who have edited them. The copyrights below are in no +particular order and generally reflect members of the Open MPI core +team who have contributed code that may or may not have been ported +to PMIx. Per the terms of that LICENSE, we include the list here. +The copyrights for code used under license from other parties +are included in the corresponding files. - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are met: +Copyright (c) 2004-2010 The Trustees of Indiana University and Indiana + University Research and Technology + Corporation. All rights reserved. +Copyright (c) 2004-2010 The University of Tennessee and The University + of Tennessee Research Foundation. All rights + reserved. 
+Copyright (c) 2004-2010 High Performance Computing Center Stuttgart, + University of Stuttgart. All rights reserved. +Copyright (c) 2004-2008 The Regents of the University of California. + All rights reserved. +Copyright (c) 2006-2010 Los Alamos National Security, LLC. All rights + reserved. +Copyright (c) 2006-2010 Cisco Systems, Inc. All rights reserved. +Copyright (c) 2006-2010 Voltaire, Inc. All rights reserved. +Copyright (c) 2006-2011 Sandia National Laboratories. All rights reserved. +Copyright (c) 2006-2010 Sun Microsystems, Inc. All rights reserved. + Use is subject to license terms. +Copyright (c) 2006-2010 The University of Houston. All rights reserved. +Copyright (c) 2006-2009 Myricom, Inc. All rights reserved. +Copyright (c) 2007-2008 UT-Battelle, LLC. All rights reserved. +Copyright (c) 2007-2019 IBM Corporation. All rights reserved. +Copyright (c) 1998-2005 Forschungszentrum Juelich, Juelich Supercomputing + Centre, Federal Republic of Germany +Copyright (c) 2005-2008 ZIH, TU Dresden, Federal Republic of Germany +Copyright (c) 2007 Evergrid, Inc. All rights reserved. +Copyright (c) 2008 Chelsio, Inc. All rights reserved. +Copyright (c) 2008-2009 Institut National de Recherche en + Informatique. All rights reserved. +Copyright (c) 2007 Lawrence Livermore National Security, LLC. + All rights reserved. +Copyright (c) 2007-2019 Mellanox Technologies. All rights reserved. +Copyright (c) 2006-2010 QLogic Corporation. All rights reserved. +Copyright (c) 2008-2010 Oak Ridge National Labs. All rights reserved. +Copyright (c) 2006-2010 Oracle and/or its affiliates. All rights reserved. +Copyright (c) 2009 Bull SAS. All rights reserved. +Copyright (c) 2010 ARM ltd. All rights reserved. +Copyright (c) 2010-2011 Alex Brick . All rights reserved. +Copyright (c) 2012 The University of Wisconsin-La Crosse. All rights + reserved. +Copyright (c) 2013-2019 Intel, Inc. All rights reserved. +Copyright (c) 2011-2014 NVIDIA Corporation. All rights reserved. 
+Copyright (c) 2019 Amazon.com, Inc. or its affiliates. All Rights + reserved. - 1. Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - - 3. Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +$COPYRIGHT$ + +Additional copyrights may follow + +$HEADER$ + +The following LICENSE pertains to both PMIx and any code ported +from Open MPI. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +- Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ +- Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer listed + in this license in the documentation and/or other materials + provided with the distribution. + +- Neither the name of the copyright holders nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +The copyright holders provide no reassurances that the source code +provided does not infringe any patent, copyright, or any other +intellectual property rights of third parties. The copyright holders +disclaim any liability to any recipient for claims brought against +recipient by any third party for infringement of that parties +intellectual property rights. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ------------------------------------------------------------------------------- The following third party programs have their own third party programs. These additional third party program files are as follows: - 1. Intel(R) MPI Benchmarks https://raw.githubusercontent.com/intel/mpi-benchmarks/master/license/third-party-programs.txt - 2. Intel(R) Distribution for Python: third-party-programs-python.txt file + 1. 
Intel(R) MPI Benchmarks /mpi/latest/benchmarks/imb/license/third-party-programs.txt + 2. Intel(R) Distribution for Python* /intelpython/latest/licensing/third-party-programs.txt ------------------------------------------------------------------------------- -Other names and brands may be claimed as the property of others. \ No newline at end of file +* Other names and brands may be claimed as the property of others. \ No newline at end of file diff --git a/deps/ofi/bin/fi_info b/deps/ofi/bin/fi_info index b4df1a8e6..711ae57b3 100755 Binary files a/deps/ofi/bin/fi_info and b/deps/ofi/bin/fi_info differ diff --git a/deps/ofi/include/rdma/fabric.h b/deps/ofi/include/rdma/fabric.h index cdfa11e8d..21bffa1d6 100644 --- a/deps/ofi/include/rdma/fabric.h +++ b/deps/ofi/include/rdma/fabric.h @@ -80,7 +80,7 @@ extern "C" { #define FI_MAJOR_VERSION 1 #define FI_MINOR_VERSION 13 -#define FI_REVISION_VERSION 0 +#define FI_REVISION_VERSION 2 enum { FI_PATH_MAX = 256, diff --git a/deps/ofi/lib/libfabric.so b/deps/ofi/lib/libfabric.so index da151da5d..cf435ab98 100755 Binary files a/deps/ofi/lib/libfabric.so and b/deps/ofi/lib/libfabric.so differ diff --git a/deps/ofi/lib/libfabric.so.1 b/deps/ofi/lib/libfabric.so.1 index da151da5d..cf435ab98 100755 Binary files a/deps/ofi/lib/libfabric.so.1 and b/deps/ofi/lib/libfabric.so.1 differ diff --git a/deps/ofi/lib/prov/libpsm3-fi.so b/deps/ofi/lib/prov/libpsm3-fi.so index fab9b8d4e..875cedbc8 100755 Binary files a/deps/ofi/lib/prov/libpsm3-fi.so and b/deps/ofi/lib/prov/libpsm3-fi.so differ diff --git a/deps/ofi/lib/prov/libpsmx2-fi.so b/deps/ofi/lib/prov/libpsmx2-fi.so index 28235ef3a..edb03e004 100755 Binary files a/deps/ofi/lib/prov/libpsmx2-fi.so and b/deps/ofi/lib/prov/libpsmx2-fi.so differ diff --git a/deps/ofi/lib/prov/librxm-fi.so b/deps/ofi/lib/prov/librxm-fi.so index 99a542183..211edb301 100755 Binary files a/deps/ofi/lib/prov/librxm-fi.so and b/deps/ofi/lib/prov/librxm-fi.so differ diff --git a/deps/ofi/lib/prov/libshm-fi.so 
b/deps/ofi/lib/prov/libshm-fi.so index 73ec980df..9394b7be7 100755 Binary files a/deps/ofi/lib/prov/libshm-fi.so and b/deps/ofi/lib/prov/libshm-fi.so differ diff --git a/deps/ofi/lib/prov/libsockets-fi.so b/deps/ofi/lib/prov/libsockets-fi.so index 83d743b77..7145739c7 100755 Binary files a/deps/ofi/lib/prov/libsockets-fi.so and b/deps/ofi/lib/prov/libsockets-fi.so differ diff --git a/deps/ofi/lib/prov/libtcp-fi.so b/deps/ofi/lib/prov/libtcp-fi.so index 89b2c7a01..6861c2533 100755 Binary files a/deps/ofi/lib/prov/libtcp-fi.so and b/deps/ofi/lib/prov/libtcp-fi.so differ diff --git a/deps/ofi/lib/prov/libverbs-1.1-fi.so b/deps/ofi/lib/prov/libverbs-1.1-fi.so new file mode 100755 index 000000000..14f00726c Binary files /dev/null and b/deps/ofi/lib/prov/libverbs-1.1-fi.so differ diff --git a/deps/ofi/lib/prov/libverbs-1.12-fi.so b/deps/ofi/lib/prov/libverbs-1.12-fi.so new file mode 100755 index 000000000..1998f3ecf Binary files /dev/null and b/deps/ofi/lib/prov/libverbs-1.12-fi.so differ diff --git a/deps/ofi/lib/prov/libverbs-fi.so b/deps/ofi/lib/prov/libverbs-fi.so deleted file mode 100755 index 91c41bce2..000000000 Binary files a/deps/ofi/lib/prov/libverbs-fi.so and /dev/null differ diff --git a/doc/rst/source/advanced-configuration/dmabuf.rst b/doc/rst/source/advanced-configuration/dmabuf.rst index 4201d2704..694892495 100644 --- a/doc/rst/source/advanced-configuration/dmabuf.rst +++ b/doc/rst/source/advanced-configuration/dmabuf.rst @@ -1,12 +1,11 @@ -.. _`here`: https://github.com/ofiwg/libfabric/releases/tag/v1.13.1 +.. _`here`: https://github.com/ofiwg/libfabric/releases/tag/v1.13.2 .. _`documentation`: https://one-api.gitlab-pages.devtools.intel.com/level_zero/core/PROG.html#affinity-mask ===================================== Enabling OFI/verbs dmabuf support ===================================== -|product_short| provides experimental support for device memory transfers using Linux dmabuf, -which is exposed through OFI API for verbs provider. 
+|product_short| provides experimental support for data transfers between Intel GPU memory and NIC using Linux dmabuf, which is exposed through OFI API for verbs provider. Requirements @@ -17,12 +16,12 @@ Requirements - level-zero-devel package -Limitations -########### +Usage +##### -- Only first tile should be used from each GPU card. - For example, if GPU with 2 tiles is used then set ZE_AFFINITY_MASK=0.0. - More information about GPU selection can be found in level-zero `documentation`_. +|product_short|, OFI and OFI/verbs from |base_tk| support device memory transfers. Refer to `Run instructions`__ for usage. + +If you want to build software components from sources, refer to `Build instructions`__. Build instructions @@ -33,10 +32,10 @@ OFI :: - git clone --single-branch --branch v1.13.1 https://github.com/ofiwg/libfabric.git + git clone --single-branch --branch v1.13.2 https://github.com/ofiwg/libfabric.git cd libfabric ./autogen.sh - ./configure --prefix= --enable-verbs= --enable-ze-dlopen=yes + ./configure --prefix= --enable-verbs= --with-ze= --enable-ze-dlopen=yes make -j install .. note:: @@ -48,17 +47,34 @@ OFI :: - cmake -DCMAKE_INSTALL_PREFIX= -DLIBFABRIC_DIR= -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=dpcpp -DCOMPUTE_BACKEND=dpcpp_level_zero -DENABLE_OFI_HMEM=1 .. + cmake -DCMAKE_INSTALL_PREFIX= -DLIBFABRIC_DIR= -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=dpcpp -DCOMPUTE_BACKEND=dpcpp_level_zero -DENABLE_OFI_HMEM=1 .. make -j install Run instructions ################ -Run allreduce test with ring algorithm and SYCL USM device buffers. +1. Set the environment. -:: + If |base_tk| is used: + + :: + + source /setvars.sh + + If software components are built from sources: + + :: + + source /env/setvars.sh + export LD_LIBRARY_PATH=/lib:${LD_LIBRARY_PATH} + +2. 
Run allreduce test with ring algorithm and SYCL USM device buffers: + + :: - source /env/setvars.sh - export LD_LIBRARY_PATH=/lib:${LD_LIBRARY_PATH} - CCL_ATL_TRANSPORT=ofi CCL_ATL_HMEM=1 CCL_ALLREDUCE=ring FI_PROVIDER=verbs mpiexec -n 2 /examples/sycl/sycl_allreduce_usm_test gpu device + export CCL_ATL_TRANSPORT=ofi + export CCL_ATL_HMEM=1 + export CCL_ALLREDUCE=ring + export FI_PROVIDER=verbs + mpiexec -n 2 /examples/sycl/sycl_allreduce_usm_test gpu device diff --git a/doc/rst/source/env-variables.rst b/doc/rst/source/env-variables.rst index 393839863..fa48e4dde 100644 --- a/doc/rst/source/env-variables.rst +++ b/doc/rst/source/env-variables.rst @@ -86,14 +86,12 @@ Available algorithms for each collective operation (````): - Based on ``MPI_Iallreduce`` * - ``rabenseifner`` - Rabenseifner’s algorithm - * - ``starlike`` + * - ``nreduce`` - May be beneficial for imbalanced workloads * - ``ring`` - reduce_scatter + allgather ring. Use ``CCL_RS_CHUNK_COUNT`` and ``CCL_RS_MIN_CHUNK_SIZE`` to control pipelining on reduce_scatter phase. - * - ``ring_rma`` - - reduce_scatter+allgather ring using RMA communications * - ``double_tree`` - Double-tree algorithm * - ``recursive_doubling`` @@ -713,3 +711,32 @@ CCL_MNIC_COUNT Set this environment variable to specify the maximum number of NICs to be selected. The actual number of NICs selected may be smaller due to limitations on transport level or system configuration. + + +CCL_SYCL_OUTPUT_EVENT +##################### +**Syntax** + +:: + + CCL_SYCL_OUTPUT_EVENT= + +**Arguments** + +.. list-table:: + :widths: 25 50 + :header-rows: 1 + :align: left + + * - + - Description + * - ``1`` + - Enable support for SYCL output event. + * - ``0`` + - Disable support for SYCL output event (**default**). + +**Description** + +Set this environment variable to control support for SYCL output event. +Once the support is enabled, you can retrieve SYCL output event from oneCCL event using ``get_native()`` method. 
+oneCCL event must be associated with oneCCL communication operation. diff --git a/doc/rst/source/introduction/installation.rst b/doc/rst/source/introduction/installation.rst index a3d905dd7..1731f8b3f 100644 --- a/doc/rst/source/introduction/installation.rst +++ b/doc/rst/source/introduction/installation.rst @@ -78,7 +78,7 @@ You can customize CLI-based installation (for example, specify directory, compil :: - cmake .. -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=dpcpp -DCOMPUTE_BACKEND=dpcpp_level_zero + cmake .. -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=dpcpp -DCOMPUTE_BACKEND=dpcpp_level_zero * To specify the **build type**, modify the ``cmake`` command: @@ -104,11 +104,11 @@ There are two ways to set up the environment: .. prompt:: bash - source /setvars.sh + source /env/setvars.sh - Using |product_short| from |base_tk| installed into ```` (``/opt/intel/inteloneapi`` by default): .. prompt:: bash - source /setvars.sh \ No newline at end of file + source /setvars.sh diff --git a/doc/rst/source/introduction/sample.rst b/doc/rst/source/introduction/sample.rst index ed2c4eb0b..d97c10053 100644 --- a/doc/rst/source/introduction/sample.rst +++ b/doc/rst/source/introduction/sample.rst @@ -24,11 +24,11 @@ Build details #. :ref:`Set up ` the library environment. -#. Use ``clang++`` compiler to build the sample: +#. 
Use ``dpcpp`` compiler to build the sample: :: - clang++ -I${CCL_ROOT}/include -L${CCL_ROOT}/lib/ -lsycl -lccl -o sample sample.cpp + dpcpp -I${CCL_ROOT}/examples/include -I${CCL_ROOT}/include/ -L${CCL_ROOT}/lib/ -lccl -lmpi -o sample sample.cpp Run the sample diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index a5186efaf..4e4a99104 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -30,7 +30,6 @@ endif() if (DEFINED ENV{I_MPI_ROOT}) set(I_MPI_ROOT "$ENV{I_MPI_ROOT}") - set(CMAKE_INSTALL_RPATH "${I_MPI_ROOT}/lib/release_mt/") endif() message(STATUS "CCL_ROOT: ${CCL_ROOT}") @@ -52,25 +51,28 @@ if (${CMAKE_VERSION} VERSION_LESS 3.1) endif() #common release/debug compilation settings -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${C_COMPILER_FLAGS} -Wall -Werror -D_GNU_SOURCE -fvisibility=internal") +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${C_COMPILER_FLAGS} -Wall -Wextra -Wno-unused-parameter -Werror -D_GNU_SOURCE -fvisibility=internal") set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} ${C_COMPILER_FLAGS} -O0 -g -DENABLE_DEBUG") set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} ${C_COMPILER_FLAGS} -O3") set(CMAKE_C_FLAGS_RELWITHDEBINFO "${CMAKE_C_FLAGS_RELWITHDEBINFO} ${C_COMPILER_FLAGS} -O2 -g") set(CMAKE_C_STANDARD 99) set(CMAKE_C_STANDARD_REQUIRED ON) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CXX_COMPILER_FLAGS} -Wall -Werror -D_GNU_SOURCE -fvisibility=internal") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CXX_COMPILER_FLAGS} -Wall -Wextra -Wno-unused-parameter -Werror -D_GNU_SOURCE -fvisibility=internal") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} ${CXX_COMPILER_FLAGS} -O0 -g -DENABLE_DEBUG") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} ${CXX_COMPILER_FLAGS} -O3") set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} ${CXX_COMPILER_FLAGS} -O2 -g") set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD_REQUIRED ON) -if (${CMAKE_C_COMPILER_ID} STREQUAL "Clang" AND ${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang" 
OR(${CMAKE_CXX_COMPILER_ID} STREQUAL "IntelLLVM")) +if ("${COMPUTE_BACKEND}" STREQUAL "dpcpp_level_zero") set(CMAKE_CLANG_FLAGS "-fsycl") set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -lsycl") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${CMAKE_CLANG_FLAGS}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CMAKE_CLANG_FLAGS}") +endif() + +if (${CMAKE_C_COMPILER_ID} STREQUAL "Clang" AND ${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang" OR(${CMAKE_CXX_COMPILER_ID} STREQUAL "IntelLLVM")) # Use c++17 to be aligned with the compiler set(CMAKE_CXX_STANDARD 17) endif() @@ -98,7 +100,7 @@ endif() include_directories(include) add_subdirectory(cpu) -if (${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang" OR ${CMAKE_CXX_COMPILER_ID} STREQUAL "IntelLLVM") +if ("${COMPUTE_BACKEND}" STREQUAL "dpcpp_level_zero") add_subdirectory(sycl) endif() add_subdirectory(common) diff --git a/examples/benchmark/CMakeLists.txt b/examples/benchmark/CMakeLists.txt index 2a6c4199d..4b5738642 100644 --- a/examples/benchmark/CMakeLists.txt +++ b/examples/benchmark/CMakeLists.txt @@ -42,7 +42,7 @@ foreach(src ${sources}) target_link_libraries(${executable} PUBLIC rt) target_link_libraries(${executable} PUBLIC m) target_link_libraries(${executable} PUBLIC dl) - target_link_libraries(${executable} PUBLIC -L${I_MPI_ROOT}/lib/release_mt/) + target_link_libraries(${executable} PUBLIC -L${I_MPI_ROOT}/lib/release/) target_link_libraries(${executable} PUBLIC mpi) install(TARGETS ${executable} RUNTIME DESTINATION ${CCL_INSTALL_EXAMPLES}/benchmark OPTIONAL) endforeach() diff --git a/examples/benchmark/include/benchmark.hpp b/examples/benchmark/include/benchmark.hpp index aafa9684a..5b98c7dae 100644 --- a/examples/benchmark/include/benchmark.hpp +++ b/examples/benchmark/include/benchmark.hpp @@ -274,12 +274,7 @@ int set_datatypes(std::string option_value, std::list& datatypes) { datatypes.clear(); if (option_value == "all") { - if (is_check_values_enabled(check_values)) { - datatypes = tokenize(ALL_DTYPES_LIST_WITH_CHECK, ','); 
- } - else { - datatypes = tokenize(ALL_DTYPES_LIST, ','); - } + datatypes = tokenize(ALL_DTYPES_LIST, ','); } else { datatypes = tokenize(option_value, ','); @@ -288,19 +283,12 @@ int set_datatypes(std::string option_value, std::set supported_option_values; for (auto p : dtype_names) { - if ((p.first == ccl::datatype::float16 || p.first == ccl::datatype::bfloat16) && - is_check_values_enabled(check_values)) - continue; supported_option_values.insert(p.second); } for (auto dt : datatypes) { if (check_supported_options(option_name, dt, supported_option_values)) { - if ((dt == dtype_names[ccl::datatype::float16] || - dt == dtype_names[ccl::datatype::bfloat16]) && - is_check_values_enabled(check_values)) { - PRINT("WARN: correctness checking is not implemented for '%s'", dt.c_str()); - } + return -1; } } } @@ -835,7 +823,7 @@ void print_user_options(const user_options_t& options, const ccl::communicator& #endif PRINT_BY_ROOT(comm, - "options:" + "\noptions:" "\n processes: %d" "\n backend: %s" "\n loop: %s" diff --git a/examples/benchmark/include/coll.hpp b/examples/benchmark/include/coll.hpp index 9a8c5d4c8..053b31721 100644 --- a/examples/benchmark/include/coll.hpp +++ b/examples/benchmark/include/coll.hpp @@ -26,6 +26,9 @@ using sycl_buffer_t = cl::sycl::buffer; #define COLL_ROOT (0) +#define BF16_COEF 0.00001 +#define FP16_COEF 0.0001 + struct base_coll; using coll_list_t = std::vector>; @@ -97,6 +100,21 @@ typedef struct bench_init_attr { #endif } bench_init_attr; +template +inline OutDtype get_val(InDtype value) { + return value; +} + +template <> +inline ccl::bfloat16 get_val(float value) { + return fp32_to_bf16(BF16_COEF * value); +} + +template <> +inline ccl::float16 get_val(float value) { + return fp32_to_fp16(FP16_COEF * value); +} + /* base polymorph collective wrapper class */ struct base_coll { base_coll(bench_init_attr init_attr) : init_attr(init_attr) { @@ -116,6 +134,30 @@ struct base_coll { return nullptr; }; +#ifdef CCL_ENABLE_SYCL + template > 
+#else // CCL_ENABLE_SYCL + template > +#endif // CCL_ENABLE_SYCL + vector_t get_initial_values(size_t elem_count, int fill_value) { + vector_t res(elem_count); + ccl::datatype dt = ccl::native_type_info::type>::dtype; + if (dt == ccl::datatype::bfloat16) { + for (size_t elem_idx = 0; elem_idx < elem_count; elem_idx++) { + res[elem_idx] = fp32_to_bf16(BF16_COEF * fill_value).get_data(); + } + } + else if (dt == ccl::datatype::float16) { + for (size_t elem_idx = 0; elem_idx < elem_count; elem_idx++) { + res[elem_idx] = fp32_to_fp16(FP16_COEF * fill_value).get_data(); + } + } + else { + std::fill(res.begin(), res.end(), fill_value); + } + return res; + } + virtual void prepare(size_t elem_count) { auto& transport = transport_data::instance(); auto& comms = transport.get_comms(); @@ -128,10 +170,6 @@ struct base_coll { } virtual void finalize(size_t elem_count) { - auto dtype = get_dtype(); - if (dtype == ccl::datatype::float16 || dtype == ccl::datatype::bfloat16) - return; - auto& transport = transport_data::instance(); auto& comms = transport.get_comms(); auto streams = transport.get_bench_streams(); diff --git a/examples/benchmark/include/config.hpp b/examples/benchmark/include/config.hpp index fbd981fa7..e2db1e518 100644 --- a/examples/benchmark/include/config.hpp +++ b/examples/benchmark/include/config.hpp @@ -21,8 +21,7 @@ #define ALL_COLLS_LIST "allgatherv,allreduce,alltoall,alltoallv,bcast,reduce,reduce_scatter" -#define ALL_DTYPES_LIST "int8,int32,int64,uint64,float16,float32,float64,bfloat16" -#define ALL_DTYPES_LIST_WITH_CHECK "int8,int32,int64,uint64,float32,float64" +#define ALL_DTYPES_LIST "int8,int32,int64,uint64,float16,float32,float64,bfloat16" #define ALL_REDUCTIONS_LIST "sum,prod,min,max" #define ALL_REDUCTIONS_LIST_WITH_CHECK "sum" diff --git a/examples/benchmark/include/cpu_coll.hpp b/examples/benchmark/include/cpu_coll.hpp index 4287bab01..38f873ee7 100644 --- a/examples/benchmark/include/cpu_coll.hpp +++ b/examples/benchmark/include/cpu_coll.hpp 
@@ -96,7 +96,7 @@ struct cpu_base_coll : base_coll, protected strategy { ccl::communicator& comm, ccl::stream& stream, size_t rank_idx) override { - int local_rank = comm.rank(); + int comm_rank = comm.rank(); size_t send_count = coll_strategy::get_send_multiplier() * elem_count; size_t recv_count = coll_strategy::get_recv_multiplier() * elem_count; @@ -104,13 +104,12 @@ struct cpu_base_coll : base_coll, protected strategy { size_t send_bytes = send_count * base_coll::get_dtype_size(); size_t recv_bytes = recv_count * base_coll::get_dtype_size(); - std::vector fill_vector(send_count); - std::fill(fill_vector.begin(), fill_vector.end(), local_rank); + auto fill_vector = get_initial_values(send_count, comm_rank); for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { memcpy(send_bufs[b_idx][rank_idx], fill_vector.data(), send_bytes); if (!base_coll::get_inplace()) { - memset(recv_bufs[b_idx][rank_idx], 0, recv_bytes); + memset(recv_bufs[b_idx][rank_idx], -1, recv_bytes); } } } diff --git a/examples/benchmark/include/lp.hpp b/examples/benchmark/include/lp.hpp new file mode 100644 index 000000000..5d278e521 --- /dev/null +++ b/examples/benchmark/include/lp.hpp @@ -0,0 +1,96 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ +#pragma once + +#include "oneapi/ccl/lp_types.hpp" + +ccl::float16 fp32_to_fp16(float val) { + uint32_t ans; + uint32_t* val_ptr = (reinterpret_cast(&val)); + uint32_t exp_bits = (*val_ptr & 0x7F800000); + uint32_t significand_bits = (*val_ptr & 0x007FFFFF); + if (exp_bits == 0x00000000) { + ans = (*val_ptr & 0x80000000) >> 16; + } + else if (exp_bits == 0x7F800000) { + if (significand_bits != 0) { + ans = ((*val_ptr & 0x80000000) >> 16) | 0x00007C01; + } + else { + ans = ((*val_ptr & 0x80000000) >> 16) | 0x00007C00; + } + } + else if (exp_bits < 0x38800000) { + ans = 0xFC00; + } + else if (exp_bits > 0x47000000) { + ans = 0x7C00; + } + else { + ans = ((*val_ptr & 0x80000000) >> 16) | ((((*val_ptr & 0x7F800000) >> 23) - 112) << 10) | + ((*val_ptr & 0x007FFFFF) >> 13); + } + return ccl::float16(ans); +} + +float fp16_to_fp32(ccl::float16 val) { + uint16_t val_data = val.get_data(); + float ans = 0.0f; + uint32_t ans_bits = 0; + uint32_t exp_bits = val_data & 0x7C00; + uint32_t significand_bits = val_data & 0x03FF; + if (exp_bits == 0x7C00) { + ans_bits = ((val_data & 0x8000) << 16) | 0x7F800000 | (significand_bits << 13); + } + else if (exp_bits == 0x0000) { + if (significand_bits != 0x00000000) { + ans_bits = ((val_data & 0x8000) << 16); + } + else { + ans_bits = ((val_data & 0x8000) << 16) | (significand_bits << 13); + } + } + else { + ans_bits = + ((val_data & 0x8000) << 16) | ((exp_bits + 0x1C000) << 13) | (significand_bits << 13); + } + std::memcpy(reinterpret_cast(&ans), reinterpret_cast(&ans_bits), 4); + return ans; +} + +ccl::bfloat16 fp32_to_bf16(float val) { + // Truncate + uint16_t int_val = 0; + memcpy(&int_val, reinterpret_cast(&val) + 2, 2); + return ccl::bfloat16(int_val); +} + +float bf16_to_fp32(ccl::bfloat16 val) { + float ret = 0; + uint32_t temp = static_cast(val.get_data()) << 16; + memcpy(&ret, &temp, sizeof(temp)); + return ret; +} + +std::ostream& operator<<(std::ostream& out, const ccl::float16& v) { + out << fp16_to_fp32(v) << "|" << 
v.get_data(); + return out; +} + +std::ostream& operator<<(std::ostream& out, const ccl::bfloat16& v) { + out << bf16_to_fp32(v) << "|" << v.get_data(); + return out; +} diff --git a/examples/benchmark/include/sycl_coll.hpp b/examples/benchmark/include/sycl_coll.hpp index a605af700..26cec4117 100644 --- a/examples/benchmark/include/sycl_coll.hpp +++ b/examples/benchmark/include/sycl_coll.hpp @@ -24,7 +24,6 @@ #include "sycl_base.hpp" /* from examples/include */ #ifdef CCL_ENABLE_SYCL - #include using namespace sycl; @@ -159,7 +158,7 @@ struct sycl_base_coll : base_coll, private strategy { size_t send_bytes = send_count * base_coll::get_dtype_size(); size_t recv_bytes = recv_count * base_coll::get_dtype_size(); - std::fill(host_send_buf.begin(), host_send_buf.end(), comm_rank); + host_send_buf = get_initial_values(send_count, comm_rank); for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { if (base_coll::get_sycl_mem_type() == SYCL_MEM_USM) { @@ -168,7 +167,7 @@ struct sycl_base_coll : base_coll, private strategy { .wait(); if (!base_coll::get_inplace()) { - stream.get_native().memset(recv_bufs[b_idx][rank_idx], 0, recv_bytes).wait(); + stream.get_native().memset(recv_bufs[b_idx][rank_idx], -1, recv_bytes).wait(); } } else { @@ -188,7 +187,7 @@ struct sycl_base_coll : base_coll, private strategy { (static_cast*>(recv_bufs[b_idx][rank_idx])); auto recv_buf_acc = recv_buf->template get_access(h, recv_count); - h.fill(recv_buf_acc, static_cast(0)); + h.fill(recv_buf_acc, static_cast(-1)); }) .wait(); } @@ -200,8 +199,8 @@ struct sycl_base_coll : base_coll, private strategy { } /* used on fill/check phases */ - std::vector host_send_buf; - std::vector host_recv_buf; + aligned_vector host_send_buf; + aligned_vector host_recv_buf; private: std::vector> allocators; diff --git a/examples/benchmark/include/types.hpp b/examples/benchmark/include/types.hpp index 2c12cdc67..6a90d037c 100644 --- a/examples/benchmark/include/types.hpp +++ 
b/examples/benchmark/include/types.hpp @@ -16,6 +16,7 @@ #pragma once #include "oneapi/ccl.hpp" +#include "lp.hpp" #define PRINT(fmt, ...) printf(fmt "\n", ##__VA_ARGS__); @@ -174,16 +175,6 @@ typedef struct user_options_t { } } user_options_t; -std::ostream& operator<<(std::ostream& out, const ccl::bfloat16& v) { - out << v.get_data(); - return out; -} - -std::ostream& operator<<(std::ostream& out, const ccl::float16& v) { - out << v.get_data(); - return out; -} - template ccl::datatype get_ccl_dtype() { return ccl::native_type_info::type>::dtype; diff --git a/examples/benchmark/src/allgatherv/cpu_allgatherv_coll.hpp b/examples/benchmark/src/allgatherv/cpu_allgatherv_coll.hpp index 795e3259a..cef8d7288 100644 --- a/examples/benchmark/src/allgatherv/cpu_allgatherv_coll.hpp +++ b/examples/benchmark/src/allgatherv/cpu_allgatherv_coll.hpp @@ -30,7 +30,7 @@ struct cpu_allgatherv_coll : cpu_base_coll { ccl::communicator& comm, ccl::stream& stream, size_t rank_idx) override { - Dtype sbuf_expected = comm.rank(); + Dtype sbuf_expected = get_val(static_cast(comm.rank())); Dtype value; for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { @@ -44,7 +44,7 @@ struct cpu_allgatherv_coll : cpu_base_coll { } for (int idx = 0; idx < comm.size(); idx++) { - Dtype rbuf_expected = idx; + Dtype rbuf_expected = get_val(static_cast(idx)); for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { value = ((Dtype*)recv_bufs[b_idx][rank_idx])[idx * elem_count + e_idx]; if (value != rbuf_expected) { diff --git a/examples/benchmark/src/allgatherv/sycl_allgatherv_coll.hpp b/examples/benchmark/src/allgatherv/sycl_allgatherv_coll.hpp index 1d99ac63e..ddf40bd42 100644 --- a/examples/benchmark/src/allgatherv/sycl_allgatherv_coll.hpp +++ b/examples/benchmark/src/allgatherv/sycl_allgatherv_coll.hpp @@ -35,7 +35,7 @@ struct sycl_allgatherv_coll : sycl_base_coll { ccl::stream& stream, size_t rank_idx) override { int comm_size = 
comm.size(); - Dtype sbuf_expected = comm.rank(); + Dtype sbuf_expected = get_val(static_cast(comm.rank())); size_t send_bytes = elem_count * base_coll::get_dtype_size(); size_t recv_bytes = comm_size * elem_count * base_coll::get_dtype_size(); @@ -78,7 +78,7 @@ struct sycl_allgatherv_coll : sycl_base_coll { } for (int idx = 0; idx < comm.size(); idx++) { - Dtype rbuf_expected = idx; + Dtype rbuf_expected = get_val(static_cast(idx)); for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { value = host_recv_buf[idx * elem_count + e_idx]; if (value != rbuf_expected) { diff --git a/examples/benchmark/src/allreduce/cpu_allreduce_coll.hpp b/examples/benchmark/src/allreduce/cpu_allreduce_coll.hpp index a0d289aef..536ca34c5 100644 --- a/examples/benchmark/src/allreduce/cpu_allreduce_coll.hpp +++ b/examples/benchmark/src/allreduce/cpu_allreduce_coll.hpp @@ -30,9 +30,10 @@ struct cpu_allreduce_coll : cpu_base_coll { ccl::communicator& comm, ccl::stream& stream, size_t rank_idx) override { - Dtype sbuf_expected = comm.rank(); /* TODO: handle PROD, MIN, MAX */ - Dtype rbuf_expected = (comm.size() - 1) * ((float)comm.size() / 2); + Dtype sbuf_expected = get_val(static_cast(comm.rank())); + Dtype rbuf_expected = get_val((comm.size() - 1) * ((float)comm.size() / 2)); + Dtype value; for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { diff --git a/examples/benchmark/src/allreduce/sycl_allreduce_coll.hpp b/examples/benchmark/src/allreduce/sycl_allreduce_coll.hpp index cd79face6..e0f0fa5d8 100644 --- a/examples/benchmark/src/allreduce/sycl_allreduce_coll.hpp +++ b/examples/benchmark/src/allreduce/sycl_allreduce_coll.hpp @@ -34,8 +34,8 @@ struct sycl_allreduce_coll : sycl_base_coll { ccl::communicator& comm, ccl::stream& stream, size_t rank_idx) override { - Dtype sbuf_expected = comm.rank(); - Dtype rbuf_expected = (comm.size() - 1) * ((float)comm.size() / 2); + Dtype sbuf_expected = 
get_val(static_cast(comm.rank())); + Dtype rbuf_expected = get_val((comm.size() - 1) * ((float)comm.size() / 2)); size_t send_bytes = elem_count * base_coll::get_dtype_size(); size_t recv_bytes = elem_count * base_coll::get_dtype_size(); @@ -66,7 +66,6 @@ struct sycl_allreduce_coll : sycl_base_coll { } Dtype value; - for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { value = host_send_buf[e_idx]; if (!base_coll::get_inplace() && (value != sbuf_expected)) { diff --git a/examples/benchmark/src/alltoall/cpu_alltoall_coll.hpp b/examples/benchmark/src/alltoall/cpu_alltoall_coll.hpp index 6e4458ca2..836ef7893 100644 --- a/examples/benchmark/src/alltoall/cpu_alltoall_coll.hpp +++ b/examples/benchmark/src/alltoall/cpu_alltoall_coll.hpp @@ -30,15 +30,14 @@ struct cpu_alltoall_coll : cpu_base_coll { ccl::communicator& comm, ccl::stream& stream, size_t rank_idx) override { - Dtype sbuf_expected = comm.rank(); - Dtype rbuf_expected; + Dtype sbuf_expected = get_val(static_cast(comm.rank())); Dtype value; int comm_size = comm.size(); for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { for (size_t e_idx = 0; e_idx < elem_count * comm_size; e_idx++) { value = ((Dtype*)send_bufs[b_idx][rank_idx])[e_idx]; - rbuf_expected = e_idx / elem_count; + Dtype rbuf_expected = get_val(static_cast(e_idx / elem_count)); if (value != sbuf_expected) { std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", rank_idx " << rank_idx << ", elem_idx " << e_idx << ", expected " diff --git a/examples/benchmark/src/alltoall/sycl_alltoall_coll.hpp b/examples/benchmark/src/alltoall/sycl_alltoall_coll.hpp index 5d51be30e..96df23d27 100644 --- a/examples/benchmark/src/alltoall/sycl_alltoall_coll.hpp +++ b/examples/benchmark/src/alltoall/sycl_alltoall_coll.hpp @@ -34,7 +34,7 @@ struct sycl_alltoall_coll : sycl_base_coll { ccl::communicator& comm, ccl::stream& stream, size_t rank_idx) override { - Dtype sbuf_expected = comm.rank(); + Dtype sbuf_expected = 
get_val(static_cast(comm.rank())); int comm_size = comm.size(); size_t send_bytes = comm_size * elem_count * base_coll::get_dtype_size(); @@ -69,7 +69,7 @@ struct sycl_alltoall_coll : sycl_base_coll { for (size_t e_idx = 0; e_idx < elem_count * comm_size; e_idx++) { value = host_send_buf[e_idx]; - Dtype rbuf_expected = e_idx / elem_count; + Dtype rbuf_expected = get_val(static_cast(e_idx / elem_count)); if (value != sbuf_expected) { std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", rank_idx " << rank_idx << ", elem_idx " << e_idx << ", expected " diff --git a/examples/benchmark/src/alltoallv/cpu_alltoallv_coll.hpp b/examples/benchmark/src/alltoallv/cpu_alltoallv_coll.hpp index 58eea5922..87f21b56a 100644 --- a/examples/benchmark/src/alltoallv/cpu_alltoallv_coll.hpp +++ b/examples/benchmark/src/alltoallv/cpu_alltoallv_coll.hpp @@ -30,14 +30,13 @@ struct cpu_alltoallv_coll : cpu_base_coll { ccl::communicator& comm, ccl::stream& stream, size_t rank_idx) override { - Dtype sbuf_expected = comm.rank(); - Dtype rbuf_expected; + Dtype sbuf_expected = get_val(static_cast(comm.rank())); Dtype value; int comm_size = comm.size(); for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { for (size_t e_idx = 0; e_idx < elem_count * comm_size; e_idx++) { value = ((Dtype*)send_bufs[b_idx][rank_idx])[e_idx]; - rbuf_expected = e_idx / elem_count; + Dtype rbuf_expected = get_val(static_cast(e_idx / elem_count)); if (value != sbuf_expected) { std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", rank_idx " << rank_idx << ", elem_idx " << e_idx << ", expected " diff --git a/examples/benchmark/src/alltoallv/sycl_alltoallv_coll.hpp b/examples/benchmark/src/alltoallv/sycl_alltoallv_coll.hpp index 4e1a31af2..6db2d0160 100644 --- a/examples/benchmark/src/alltoallv/sycl_alltoallv_coll.hpp +++ b/examples/benchmark/src/alltoallv/sycl_alltoallv_coll.hpp @@ -34,7 +34,7 @@ struct sycl_alltoallv_coll : sycl_base_coll { ccl::communicator& comm, 
ccl::stream& stream, size_t rank_idx) override { - Dtype sbuf_expected = comm.rank(); + Dtype sbuf_expected = get_val(static_cast(comm.rank())); int comm_size = comm.size(); size_t send_bytes = comm_size * elem_count * base_coll::get_dtype_size(); @@ -69,7 +69,7 @@ struct sycl_alltoallv_coll : sycl_base_coll { for (size_t e_idx = 0; e_idx < elem_count * comm_size; e_idx++) { value = host_send_buf[e_idx]; - Dtype rbuf_expected = e_idx / elem_count; + Dtype rbuf_expected = get_val(static_cast(e_idx / elem_count)); if (value != sbuf_expected) { std::cout << this->name() << " send_bufs: buf_idx " << b_idx << ", rank_idx " << rank_idx << ", elem_idx " << e_idx << ", expected " diff --git a/examples/benchmark/src/bcast/cpu_bcast_coll.hpp b/examples/benchmark/src/bcast/cpu_bcast_coll.hpp index dfd23566e..0d64d0cbb 100644 --- a/examples/benchmark/src/bcast/cpu_bcast_coll.hpp +++ b/examples/benchmark/src/bcast/cpu_bcast_coll.hpp @@ -31,10 +31,13 @@ struct cpu_bcast_coll : cpu_base_coll { size_t rank_idx) override { for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { - if (comm.rank() == COLL_ROOT) - ((Dtype*)recv_bufs[b_idx][rank_idx])[e_idx] = b_idx; - else + if (comm.rank() == COLL_ROOT) { + ((Dtype*)recv_bufs[b_idx][rank_idx])[e_idx] = + get_val(static_cast(b_idx)); + } + else { ((Dtype*)recv_bufs[b_idx][rank_idx])[e_idx] = 0; + } } } } @@ -47,9 +50,10 @@ struct cpu_bcast_coll : cpu_base_coll { for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { value = ((Dtype*)recv_bufs[b_idx][rank_idx])[e_idx]; - if (cast_to_size_t(value) != b_idx) { + Dtype expected = get_val(static_cast(b_idx)); + if (value != expected) { std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", rank_idx " - << rank_idx << ", elem_idx " << e_idx << ", expected " << b_idx + << rank_idx << ", elem_idx " << e_idx << ", expected " << expected << ", got 
" << value << std::endl; ASSERT(0, "unexpected value"); } diff --git a/examples/benchmark/src/bcast/sycl_bcast_coll.hpp b/examples/benchmark/src/bcast/sycl_bcast_coll.hpp index f0a06af50..cd9742ab9 100644 --- a/examples/benchmark/src/bcast/sycl_bcast_coll.hpp +++ b/examples/benchmark/src/bcast/sycl_bcast_coll.hpp @@ -38,7 +38,7 @@ struct sycl_bcast_coll : sycl_base_coll { size_t bytes = count * base_coll::get_dtype_size(); for (size_t b_idx = 0; b_idx < base_coll::get_buf_count(); b_idx++) { - std::fill(host_recv_buf.begin(), host_recv_buf.end(), b_idx); + host_recv_buf = base_coll::get_initial_values(count, static_cast(b_idx)); if (base_coll::get_sycl_mem_type() == SYCL_MEM_USM) { if (comm_rank == COLL_ROOT) @@ -88,12 +88,12 @@ struct sycl_bcast_coll : sycl_base_coll { } Dtype value; - for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { value = host_recv_buf[e_idx]; - if (value != static_cast(b_idx)) { // comparison float16 with size_t ?? + Dtype expected = get_val(static_cast(b_idx)); + if (value != expected) { std::cout << this->name() << " recv_bufs: buf_idx " << b_idx << ", rank_idx " - << rank_idx << ", elem_idx " << e_idx << ", expected " << (Dtype)b_idx + << rank_idx << ", elem_idx " << e_idx << ", expected " << expected << ", got " << value << std::endl; ASSERT(0, "unexpected value"); } diff --git a/examples/benchmark/src/reduce/cpu_reduce_coll.hpp b/examples/benchmark/src/reduce/cpu_reduce_coll.hpp index 0a0c9d445..197dc2b94 100644 --- a/examples/benchmark/src/reduce/cpu_reduce_coll.hpp +++ b/examples/benchmark/src/reduce/cpu_reduce_coll.hpp @@ -30,8 +30,8 @@ struct cpu_reduce_coll : cpu_base_coll { ccl::communicator& comm, ccl::stream& stream, size_t rank_idx) override { - Dtype sbuf_expected = comm.rank(); - Dtype rbuf_expected = (comm.size() - 1) * ((float)comm.size() / 2); + Dtype sbuf_expected = get_val(static_cast(comm.rank())); + Dtype rbuf_expected = get_val((comm.size() - 1) * ((float)comm.size() / 2)); Dtype value; for (size_t b_idx = 0; b_idx 
< base_coll::get_buf_count(); b_idx++) { for (size_t e_idx = 0; e_idx < elem_count; e_idx++) { diff --git a/examples/benchmark/src/reduce/sycl_reduce_coll.hpp b/examples/benchmark/src/reduce/sycl_reduce_coll.hpp index b9ac0ce95..47059af9e 100644 --- a/examples/benchmark/src/reduce/sycl_reduce_coll.hpp +++ b/examples/benchmark/src/reduce/sycl_reduce_coll.hpp @@ -34,8 +34,8 @@ struct sycl_reduce_coll : sycl_base_coll { ccl::communicator& comm, ccl::stream& stream, size_t rank_idx) override { - Dtype sbuf_expected = comm.rank(); - Dtype rbuf_expected = (comm.size() - 1) * ((float)comm.size() / 2); + Dtype sbuf_expected = get_val(static_cast(comm.rank())); + Dtype rbuf_expected = get_val((comm.size() - 1) * ((float)comm.size() / 2)); int comm_rank = comm.rank(); diff --git a/examples/benchmark/src/reduce_scatter/cpu_reduce_scatter_coll.hpp b/examples/benchmark/src/reduce_scatter/cpu_reduce_scatter_coll.hpp index f9bf0107a..ce1121c6b 100644 --- a/examples/benchmark/src/reduce_scatter/cpu_reduce_scatter_coll.hpp +++ b/examples/benchmark/src/reduce_scatter/cpu_reduce_scatter_coll.hpp @@ -30,8 +30,8 @@ struct cpu_reduce_scatter_coll : cpu_base_coll(static_cast(comm.rank())); + Dtype rbuf_expected = get_val((comm.size() - 1) * ((float)comm.size() / 2)); Dtype value; size_t recv_elem_count = elem_count / comm.size(); diff --git a/examples/benchmark/src/reduce_scatter/sycl_reduce_scatter_coll.hpp b/examples/benchmark/src/reduce_scatter/sycl_reduce_scatter_coll.hpp index 0013ec1cb..d57b7126a 100644 --- a/examples/benchmark/src/reduce_scatter/sycl_reduce_scatter_coll.hpp +++ b/examples/benchmark/src/reduce_scatter/sycl_reduce_scatter_coll.hpp @@ -34,8 +34,8 @@ struct sycl_reduce_scatter_coll : sycl_base_coll(static_cast(comm.rank())); + Dtype rbuf_expected = get_val((comm.size() - 1) * ((float)comm.size() / 2)); size_t recv_elem_count = elem_count / comm.size(); diff --git a/examples/common/CMakeLists.txt b/examples/common/CMakeLists.txt index 296a83adb..9edb0dd4e 100644 --- 
a/examples/common/CMakeLists.txt +++ b/examples/common/CMakeLists.txt @@ -27,7 +27,7 @@ foreach(src ${sources}) target_link_libraries(${executable} PUBLIC rt) target_link_libraries(${executable} PUBLIC m) target_link_libraries(${executable} PUBLIC dl) - target_link_libraries(${executable} PUBLIC -L${I_MPI_ROOT}/lib/release_mt/) + target_link_libraries(${executable} PUBLIC -L${I_MPI_ROOT}/lib/release/) target_link_libraries(${executable} PUBLIC mpi) install(TARGETS ${executable} RUNTIME DESTINATION ${CCL_INSTALL_EXAMPLES}/common OPTIONAL) endforeach() diff --git a/examples/cpu/CMakeLists.txt b/examples/cpu/CMakeLists.txt index 77a2c0342..ac50a4e05 100644 --- a/examples/cpu/CMakeLists.txt +++ b/examples/cpu/CMakeLists.txt @@ -27,7 +27,7 @@ foreach(src ${sources}) target_link_libraries(${executable} PUBLIC dl) target_link_libraries(${executable} PUBLIC pthread) target_link_libraries(${executable} PUBLIC stdc++) - target_link_libraries(${executable} PUBLIC -L${I_MPI_ROOT}/lib/release_mt/) + target_link_libraries(${executable} PUBLIC -L${I_MPI_ROOT}/lib/release/) target_link_libraries(${executable} PUBLIC mpi) install(TARGETS ${executable} RUNTIME DESTINATION ${CCL_INSTALL_EXAMPLES}/cpu OPTIONAL) endforeach() diff --git a/examples/cpu/communicator.cpp b/examples/cpu/communicator.cpp index 44b7dc291..6d3b82cc5 100644 --- a/examples/cpu/communicator.cpp +++ b/examples/cpu/communicator.cpp @@ -184,11 +184,6 @@ void check_comm_split_identical_color(ccl::communicator& comm) { } int main() { - /** - * The example only works with CCL_ATL_TRANSPORT=ofi - */ - setenv("CCL_ATL_TRANSPORT", "ofi", 0); - ccl::init(); int mpi_size, mpi_rank; diff --git a/examples/cpu/external_kvs.cpp b/examples/cpu/external_kvs.cpp index c22ca77f3..55a94696f 100644 --- a/examples/cpu/external_kvs.cpp +++ b/examples/cpu/external_kvs.cpp @@ -88,7 +88,7 @@ int main() { kvs = ccl::create_kvs(main_addr); } - auto ext_kvs = std::make_shared(kvs); + auto ext_kvs = std::shared_ptr(new external_kvs(kvs)); 
auto comm = ccl::create_communicator(size, rank, ext_kvs); auto attr = ccl::create_operation_attr(); diff --git a/examples/external_launcher/run_binary.sh b/examples/external_launcher/run_binary.sh index 56e430d6c..732c15748 100755 --- a/examples/external_launcher/run_binary.sh +++ b/examples/external_launcher/run_binary.sh @@ -147,10 +147,10 @@ function run() elif [[ $CCL_VARS == *"vars.sh"* ]]; then echo "Use oneAPI CCL variables script" - source ${MPI_VARS} -i_mpi_library_kind=release_mt + source ${MPI_VARS} fi - export CCL_CONFIGURATION="cpu_icc" + export CCL_CONFIGURATION="cpu" source ${CCL_VARS} --ccl-configuration="${CCL_CONFIGURATION}" eval `echo $binary_env $binary_path $binary_arg ;` &> $LOG_FILE diff --git a/examples/include/base_utils.hpp b/examples/include/base_utils.hpp index 5dd68a1dc..ca6c8e370 100644 --- a/examples/include/base_utils.hpp +++ b/examples/include/base_utils.hpp @@ -16,10 +16,13 @@ #pragma once #include +#include #include +#include #include #include #include +#include template struct get_tuple_elem_index { @@ -109,6 +112,38 @@ void ccl_tuple_for_each_indexed(functor f, const FunctionArgs&... 
args) { f, is_tuple_finished_t{}, args...); } +template +struct aligned_allocator { + using value_type = T; + using pointer = T*; + + template + struct rebind { + using other = aligned_allocator; + }; + + aligned_allocator() = default; + ~aligned_allocator() = default; + + template + constexpr aligned_allocator(const aligned_allocator&) noexcept {} + + inline pointer allocate(size_t n) { + void* ptr = aligned_alloc(align, sizeof(value_type) * n); + if (!ptr) { + throw std::bad_alloc(); + } + return reinterpret_cast(ptr); + } + + inline void deallocate(pointer ptr, size_t size) noexcept { + free(ptr); + } +}; + +template +using aligned_vector = std::vector>; + namespace utils { template diff --git a/examples/include/bf16.hpp b/examples/include/bf16.hpp index 3b8039d24..c05bba421 100644 --- a/examples/include/bf16.hpp +++ b/examples/include/bf16.hpp @@ -45,8 +45,8 @@ int is_bf16_enabled() { __asm__ __volatile__("cpuid" : "=a"(reg[0]), "=b"(reg[1]), "=c"(reg[2]), "=d"(reg[3]) : "a"(7), "c"(0)); - is_avx512f_enabled = - ((reg[1] & (1 << 16)) >> 16) & ((reg[1] & (1 << 30)) >> 30) & ((reg[1] & (1 << 31)) >> 31); + is_avx512f_enabled = ((reg[1] & (1u << 16)) >> 16) & ((reg[1] & (1u << 30)) >> 30) & + ((reg[1] & (1u << 31)) >> 31); return (is_avx512f_enabled) ? 
1 : 0; #else diff --git a/examples/include/sycl_base.hpp b/examples/include/sycl_base.hpp index f944018f2..12948129f 100644 --- a/examples/include/sycl_base.hpp +++ b/examples/include/sycl_base.hpp @@ -74,22 +74,19 @@ inline bool check_sycl_usm(queue& q, usm::alloc alloc_type) { } inline std::string get_preferred_gpu_platform_name() { - std::string filter; std::string result; - if (getenv("SYCL_DEVICE_FILTER") == nullptr) { - filter = "level-zero"; - } - else if (getenv("SYCL_DEVICE_FILTER") != nullptr) { - if (std::strstr(getenv("SYCL_DEVICE_FILTER"), "level_zero") != NULL) { + std::string filter = "level-zero"; + char* env = getenv("SYCL_DEVICE_FILTER"); + if (env) { + if (std::strstr(env, "level_zero")) { filter = "level-zero"; } - else if (std::strstr(getenv("SYCL_DEVICE_FILTER"), "opencl") != NULL) { + else if (std::strstr(env, "opencl")) { filter = "opencl"; } else { - throw std::runtime_error("invalid device filter: " + - std::string(getenv("SYCL_DEVICE_FILTER"))); + throw std::runtime_error("invalid device filter: " + std::string(env)); } } @@ -131,31 +128,27 @@ inline std::string get_preferred_gpu_platform_name() { } inline std::vector create_sycl_gpu_devices() { - constexpr char dev_prefix[] = "-- "; - constexpr char sub_dev_prefix[] = "---- "; + constexpr char prefix[] = "-- "; std::vector result; auto plaform_list = sycl::platform::get_platforms(); auto preferred_platform_name = get_preferred_gpu_platform_name(); std::stringstream ss; - ss << "preferred platform: [" << preferred_platform_name << "]\n"; + std::stringstream ss_warn; for (const auto& platform : plaform_list) { auto platform_name = platform.get_info(); - - if (platform_name.compare(preferred_platform_name) != 0) + if (platform_name.compare(preferred_platform_name) != 0) { continue; - - ss << "platform: [" << platform_name << "]\n"; + } auto device_list = platform.get_devices(); - for (const auto& device : device_list) { auto device_name = device.get_info(); if (!device.is_gpu()) { - ss << 
dev_prefix << "device [" << device_name << "] is not GPU, skipping\n"; + ss_warn << prefix << "device [" << device_name << "] is not GPU, skipping\n"; continue; } @@ -165,9 +158,9 @@ inline std::vector create_sycl_gpu_devices() { part_props.end(), info::partition_property::partition_by_affinity_domain) == part_props.end()) { - ss << dev_prefix << "device [" << device_name - << "] does not support partition by affinity domain" - << ", use root device\n"; + ss_warn << prefix << "device [" << device_name + << "] does not support partition by affinity domain" + << ", use root device\n"; result.push_back(device); continue; } @@ -179,37 +172,32 @@ inline std::vector create_sycl_gpu_devices() { part_affinity_domains.end(), info::partition_affinity_domain::next_partitionable) == part_affinity_domains.end()) { - ss << dev_prefix << "device [" << device_name - << "] does not support next_partitionable affinity domain" - << ", use root device\n"; + ss_warn << prefix << "device [" << device_name + << "] does not support next_partitionable affinity domain" + << ", use root device\n"; result.push_back(device); continue; } - ss << dev_prefix << "device [" << device_name << "] should provide " - << device.template get_info() - << " sub-devices\n"; - auto sub_devices = device.create_sub_devices( info::partition_affinity_domain::next_partitionable); + size_t sub_devices_max = + device.template get_info(); + if (sub_devices.size() != sub_devices_max) { + ss_warn << prefix << "device [" << device_name << "] expected " << sub_devices_max + << " sub-devices, but got " << sub_devices.size(); + } + if (sub_devices.empty()) { - /* TODO: remove when SYCL/L0 sub-devices will be supported */ - ss << dev_prefix << "device [" << device_name << "] does not provide sub-devices" - << ", use root device\n"; + ss_warn << prefix << "device [" << device_name << "] does not provide sub-devices" + << ", use root device\n"; result.push_back(device); continue; } - ss << dev_prefix << "device [" << 
device_name << "] provides " << sub_devices.size() - << " sub-devices\n"; result.insert(result.end(), sub_devices.begin(), sub_devices.end()); - - for (size_t idx = 0; idx < sub_devices.size(); idx++) { - ss << sub_dev_prefix << "sub-device " << idx << ": [" - << sub_devices[idx].get_info() << "]\n"; - } } } @@ -217,7 +205,9 @@ inline std::vector create_sycl_gpu_devices() { throw std::runtime_error("no GPU devices found"); } - ss << "found: " << result.size() << " GPU device(s)\n"; + ss << "preferred platform: " << preferred_platform_name << ", found: " << result.size() + << " GPU device(s)\n"; + ss << ss_warn.str(); printf("%s", ss.str().c_str()); return result; @@ -442,95 +432,3 @@ struct buf_allocator { queue q; set memory_storage; }; - -template -struct usm_polymorphic_allocator { - using native_type = data_native_type; - using allocator_types = tuple...>; - using integer_usm_type = typename underlying_type::type; - using self_t = usm_polymorphic_allocator; - - usm_polymorphic_allocator(queue& q) - : allocators{ make_tuple(usm_allocator(q)...) 
} {} - - ~usm_polymorphic_allocator() { - for (auto& v : memory_storage) { - data_native_type* mem = v.first; - deallocate(mem, v.second.size, v.second.type); - } - } - -private: - struct alloc_info { - size_t size; - usm::alloc type; - }; - map memory_storage; - - struct alloc_impl { - alloc_impl(native_type** out_ptr, size_t count, usm::alloc type, self_t* parent) - : out_usm_memory_pointer(out_ptr), - size(count), - alloc_index(0), - requested_alloc_type(type), - owner(parent) {} - - template - void operator()(specific_allocator& al) { - if (alloc_index++ == static_cast(requested_alloc_type)) { - *out_usm_memory_pointer = al.allocate(size); - - alloc_info info{ size, requested_alloc_type }; - owner->memory_storage.emplace(*out_usm_memory_pointer, info); - } - } - native_type** out_usm_memory_pointer; - size_t size{}; - int alloc_index{}; - usm::alloc requested_alloc_type; - self_t* owner; - }; - - struct dealloc_impl { - dealloc_impl(native_type** in_ptr, size_t count, usm::alloc type, self_t* parent) - : in_usm_memory_pointer(in_ptr), - size(count), - alloc_index(0), - requested_alloc_type(type), - owner(parent) {} - - template - void operator()(specific_allocator& al) { - if (alloc_index++ == static_cast(requested_alloc_type)) { - auto it = owner->memory_storage.find(*in_usm_memory_pointer); - if (it == owner->memory_storage.end()) { - throw std::runtime_error(string(__PRETTY_FUNCTION__) + - " - not owns memory object"); - } - - al.deallocate(*in_usm_memory_pointer, size); - *in_usm_memory_pointer = nullptr; - - owner->memory_storage.erase(it); - } - } - native_type** in_usm_memory_pointer; - size_t size; - int alloc_index; - usm::alloc requested_alloc_type; - self_t* owner; - }; - -public: - allocator_types allocators; - - native_type* allocate(size_t size, usm::alloc type) { - native_type* ret = nullptr; - ccl_tuple_for_each(allocators, alloc_impl{ &ret, size, type, this }); - return ret; - } - - void deallocate(native_type* in_ptr, size_t size, usm::alloc 
type) { - ccl_tuple_for_each(allocators, dealloc_impl{ &in_ptr, size, type, this }); - } -}; diff --git a/examples/sycl/CMakeLists.txt b/examples/sycl/CMakeLists.txt index ff4fc8b5b..7426280bc 100644 --- a/examples/sycl/CMakeLists.txt +++ b/examples/sycl/CMakeLists.txt @@ -27,7 +27,7 @@ foreach(src ${sources}) target_link_libraries(${executable} PUBLIC rt) target_link_libraries(${executable} PUBLIC m) target_link_libraries(${executable} PRIVATE ccl) - target_link_libraries(${executable} PUBLIC -L${I_MPI_ROOT}/lib/release_mt/) + target_link_libraries(${executable} PUBLIC -L${I_MPI_ROOT}/lib/release/) target_link_libraries(${executable} PUBLIC mpi) target_link_libraries(${executable} PRIVATE ${COMPUTE_BACKEND_TARGET_NAME}) install(TARGETS ${executable} RUNTIME DESTINATION ${CCL_INSTALL_EXAMPLES}/sycl OPTIONAL) diff --git a/examples/sycl/sycl_allgatherv_inplace_usm_test.cpp b/examples/sycl/sycl_allgatherv_inplace_usm_test.cpp new file mode 100644 index 000000000..856c675d0 --- /dev/null +++ b/examples/sycl/sycl_allgatherv_inplace_usm_test.cpp @@ -0,0 +1,128 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ +#include "sycl_base.hpp" + +using namespace std; +using namespace sycl; + +int main(int argc, char *argv[]) { + const size_t count = 10 * 1024 * 1024; + + int size = 0; + int rank = 0; + + ccl::init(); + + MPI_Init(NULL, NULL); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + atexit(mpi_finalize); + + queue q; + if (!create_sycl_queue(argc, argv, rank, q)) { + return -1; + } + + buf_allocator allocator(q); + + auto usm_alloc_type = usm::alloc::shared; + if (argc > 2) { + usm_alloc_type = usm_alloc_type_from_string(argv[2]); + } + + if (!check_sycl_usm(q, usm_alloc_type)) { + return -1; + } + + /* create kvs */ + ccl::shared_ptr_class kvs; + ccl::kvs::address_type main_addr; + if (rank == 0) { + kvs = ccl::create_main_kvs(); + main_addr = kvs->get_address(); + MPI_Bcast((void *)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); + } + else { + MPI_Bcast((void *)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); + kvs = ccl::create_kvs(main_addr); + } + + /* create communicator */ + auto dev = ccl::create_device(q.get_device()); + auto ctx = ccl::create_context(q.get_context()); + auto comm = ccl::create_communicator(size, rank, dev, ctx, kvs); + + /* create stream */ + auto stream = ccl::create_stream(q); + + /* create buffers */ + auto recv_buf = allocator.allocate(count * size, usm_alloc_type); + + buffer expected_buf(count * size); + buffer check_buf(count * size); + vector recv_counts(size, count); + + /* open buffers and modify them on the device side */ + auto e = q.submit([&](auto &h) { + accessor expected_buf_acc(expected_buf, h, write_only); + h.parallel_for(count, [=](auto id) { + recv_buf[rank * count + id] = rank + 1; + for (int i = 0; i < size; i++) { + expected_buf_acc[i * count + id] = i + 1; + } + }); + }); + + /* do not wait completion of kernel and provide it as dependency for operation */ + vector deps; + deps.push_back(ccl::create_event(e)); + + /* invoke allgatherv */ + auto 
attr = ccl::create_operation_attr(); + ccl::allgatherv(recv_buf, count, recv_buf, recv_counts, comm, stream, attr, deps).wait(); + + /* open recv_buf and check its correctness on the device side */ + q.submit([&](auto &h) { + accessor expected_buf_acc(expected_buf, h, read_only); + accessor check_buf_acc(check_buf, h, write_only); + h.parallel_for(size * count, [=](auto id) { + if (recv_buf[id] != expected_buf_acc[id]) { + check_buf_acc[id] = -1; + } + }); + }); + + if (!handle_exception(q)) + return -1; + + /* print out the result of the test on the host side */ + { + host_accessor check_buf_acc(check_buf, read_only); + size_t i; + for (i = 0; i < size * count; i++) { + if (check_buf_acc[i] == -1) { + cout << "FAILED\n"; + break; + } + } + if (i == size * count) { + cout << "PASSED\n"; + } + } + + return 0; +} diff --git a/examples/sycl/sycl_allgatherv_test.cpp b/examples/sycl/sycl_allgatherv_test.cpp index e176b240b..bb76dcddb 100644 --- a/examples/sycl/sycl_allgatherv_test.cpp +++ b/examples/sycl/sycl_allgatherv_test.cpp @@ -95,7 +95,7 @@ int main(int argc, char *argv[]) { if (!handle_exception(q)) return -1; - /* invoke allagtherv */ + /* invoke allgatherv */ ccl::allgatherv(send_buf, count, recv_buf, recv_counts, comm, stream).wait(); /* open recv_buf and check its correctness on the device side */ diff --git a/examples/sycl/sycl_allgatherv_usm_test.cpp b/examples/sycl/sycl_allgatherv_usm_test.cpp index 895bbd31b..59c96dff3 100644 --- a/examples/sycl/sycl_allgatherv_usm_test.cpp +++ b/examples/sycl/sycl_allgatherv_usm_test.cpp @@ -82,7 +82,6 @@ int main(int argc, char *argv[]) { accessor expected_buf_acc(expected_buf, h, write_only); h.parallel_for(count, [=](auto id) { send_buf[id] = rank + 1; - recv_buf[id] = -1; for (int i = 0; i < size; i++) { expected_buf_acc[i * count + id] = i + 1; } @@ -93,7 +92,7 @@ int main(int argc, char *argv[]) { vector deps; deps.push_back(ccl::create_event(e)); - /* invoke allagtherv */ + /* invoke allgatherv */ auto attr = 
ccl::create_operation_attr(); ccl::allgatherv(send_buf, count, recv_buf, recv_counts, comm, stream, attr, deps).wait(); @@ -105,6 +104,9 @@ int main(int argc, char *argv[]) { if (recv_buf[id] != expected_buf_acc[id]) { check_buf_acc[id] = -1; } + else { + check_buf_acc[id] = 0; + } }); }); diff --git a/examples/sycl/sycl_allreduce_inplace_usm_test.cpp b/examples/sycl/sycl_allreduce_inplace_usm_test.cpp index ab2de50d6..6624b75fd 100644 --- a/examples/sycl/sycl_allreduce_inplace_usm_test.cpp +++ b/examples/sycl/sycl_allreduce_inplace_usm_test.cpp @@ -87,7 +87,7 @@ int main(int argc, char *argv[]) { auto attr = ccl::create_operation_attr(); ccl::allreduce(buf, buf, count, ccl::reduction::sum, comm, stream, attr, deps).wait(); - /* open recv_buf and check its correctness on the device side */ + /* open buf and check its correctness on the device side */ buffer check_buf(count); q.submit([&](auto &h) { accessor check_buf_acc(check_buf, h, write_only); diff --git a/examples/sycl/sycl_allreduce_usm_test.cpp b/examples/sycl/sycl_allreduce_usm_test.cpp index 26f2ce4f4..9b27b5759 100644 --- a/examples/sycl/sycl_allreduce_usm_test.cpp +++ b/examples/sycl/sycl_allreduce_usm_test.cpp @@ -18,7 +18,7 @@ using namespace std; using namespace sycl; -int main(int argc, char *argv[]) { +int main(int argc, char* argv[]) { const size_t count = 10 * 1024 * 1024; int size = 0; @@ -54,10 +54,10 @@ int main(int argc, char *argv[]) { if (rank == 0) { kvs = ccl::create_main_kvs(); main_addr = kvs->get_address(); - MPI_Bcast((void *)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); + MPI_Bcast((void*)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); } else { - MPI_Bcast((void *)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); + MPI_Bcast((void*)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); kvs = ccl::create_kvs(main_addr); } @@ -74,13 +74,18 @@ int main(int argc, char *argv[]) { auto recv_buf = allocator.allocate(count, 
usm_alloc_type); /* open buffers and modify them on the device side */ - auto e = q.submit([&](auto &h) { + auto e = q.submit([&](auto& h) { h.parallel_for(count, [=](auto id) { - send_buf[id] = rank + 1; + send_buf[id] = rank + id + 1; recv_buf[id] = -1; }); }); + int check_sum = 0; + for (int i = 1; i <= size; ++i) { + check_sum += i; + } + /* do not wait completion of kernel and provide it as dependency for operation */ vector deps; deps.push_back(ccl::create_event(e)); @@ -91,10 +96,10 @@ int main(int argc, char *argv[]) { /* open recv_buf and check its correctness on the device side */ buffer check_buf(count); - q.submit([&](auto &h) { + q.submit([&](auto& h) { accessor check_buf_acc(check_buf, h, write_only); h.parallel_for(count, [=](auto id) { - if (recv_buf[id] != size * (size + 1) / 2) { + if (recv_buf[id] != static_cast(check_sum + size * id)) { check_buf_acc[id] = -1; } }); diff --git a/examples/sycl/sycl_alltoallv_test.cpp b/examples/sycl/sycl_alltoallv_test.cpp index 2905235bc..be447e15f 100644 --- a/examples/sycl/sycl_alltoallv_test.cpp +++ b/examples/sycl/sycl_alltoallv_test.cpp @@ -89,7 +89,7 @@ int main(int argc, char *argv[]) { if (!handle_exception(q)) return -1; - /* invoke alltoall */ + /* invoke alltoallv */ ccl::alltoallv(send_buf, send_counts, recv_buf, recv_counts, comm, stream).wait(); /* open recv_buf and check its correctness on the device side */ diff --git a/examples/sycl/sycl_alltoallv_usm_test.cpp b/examples/sycl/sycl_alltoallv_usm_test.cpp index 211b548d1..89b884f13 100644 --- a/examples/sycl/sycl_alltoallv_usm_test.cpp +++ b/examples/sycl/sycl_alltoallv_usm_test.cpp @@ -88,7 +88,7 @@ int main(int argc, char *argv[]) { vector deps; deps.push_back(ccl::create_event(e)); - /* invoke alltoall */ + /* invoke alltoallv */ auto attr = ccl::create_operation_attr(); ccl::alltoallv(send_buf, send_counts, recv_buf, recv_counts, comm, stream, attr, deps).wait(); diff --git a/examples/sycl/sycl_reduce_inplace_usm_test.cpp 
b/examples/sycl/sycl_reduce_inplace_usm_test.cpp new file mode 100644 index 000000000..c28770120 --- /dev/null +++ b/examples/sycl/sycl_reduce_inplace_usm_test.cpp @@ -0,0 +1,137 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include "sycl_base.hpp" + +using namespace std; +using namespace sycl; + +int main(int argc, char* argv[]) { + const size_t count = 10 * 1024 * 1024; + int root_rank = 1; + + int size = 0; + int rank = 0; + + ccl::init(); + + MPI_Init(NULL, NULL); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + atexit(mpi_finalize); + + queue q; + if (!create_sycl_queue(argc, argv, rank, q)) { + return -1; + } + + buf_allocator allocator(q); + + auto usm_alloc_type = usm::alloc::shared; + if (argc > 2) { + usm_alloc_type = usm_alloc_type_from_string(argv[2]); + } + if (argc > 3) { + root_rank = atoi(argv[3]); + } + if (rank == root_rank) { + printf("root rank: %d\n", root_rank); + } + + if (!check_sycl_usm(q, usm_alloc_type)) { + return -1; + } + + /* create kvs */ + ccl::shared_ptr_class kvs; + ccl::kvs::address_type main_addr; + if (rank == 0) { + kvs = ccl::create_main_kvs(); + main_addr = kvs->get_address(); + MPI_Bcast((void*)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); + } + else { + MPI_Bcast((void*)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); + kvs = ccl::create_kvs(main_addr); + } + + /* create communicator */ + auto dev = 
ccl::create_device(q.get_device()); + auto ctx = ccl::create_context(q.get_context()); + auto comm = ccl::create_communicator(size, rank, dev, ctx, kvs); + + /* create stream */ + auto stream = ccl::create_stream(q); + + /* create buffers */ + auto buf = allocator.allocate(count, usm_alloc_type); + + /* open buffers and modify them on the device side */ + auto e = q.submit([&](auto& h) { + h.parallel_for(count, [=](auto id) { + buf[id] = rank + id + 1; + }); + }); + + int check_sum = 0; + for (int i = 1; i <= size; ++i) { + check_sum += i; + } + + /* do not wait completion of kernel and provide it as dependency for operation */ + vector deps; + deps.push_back(ccl::create_event(e)); + + /* invoke reduce */ + auto attr = ccl::create_operation_attr(); + ccl::reduce(buf, buf, count, ccl::reduction::sum, root_rank, comm, stream, attr, deps).wait(); + + /* open buf and check its correctness on the device side */ + buffer check_buf(count); + + q.submit([&](auto& h) { + accessor check_buf_acc(check_buf, h, write_only); + h.parallel_for(count, [=](auto id) { + int expected = (rank == root_rank) ? 
(check_sum + size * id) : (rank + id + 1); + if (buf[id] != expected) { + check_buf_acc[id] = -1; + } + else { + check_buf_acc[id] = 0; + } + }); + }); + + if (!handle_exception(q)) + return -1; + + /* print out the result of the test on the host side */ + { + host_accessor check_buf_acc(check_buf, read_only); + size_t i; + for (i = 0; i < count; i++) { + if (check_buf_acc[i] == -1) { + cout << "FAILED\n"; + break; + } + } + if (i == count) { + cout << "PASSED\n"; + } + } + + return 0; +} diff --git a/examples/sycl/sycl_reduce_scatter_test.cpp b/examples/sycl/sycl_reduce_scatter_test.cpp new file mode 100644 index 000000000..e91df88f5 --- /dev/null +++ b/examples/sycl/sycl_reduce_scatter_test.cpp @@ -0,0 +1,128 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ +#include "sycl_base.hpp" + +using namespace std; +using namespace sycl; + +int main(int argc, char *argv[]) { + const size_t count = 10 * 1024 * 1024; + + int size = 0; + int rank = 0; + + ccl::init(); + + MPI_Init(NULL, NULL); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + atexit(mpi_finalize); + + queue q; + if (!create_sycl_queue(argc, argv, rank, q)) { + return -1; + } + + /* create kvs */ + ccl::shared_ptr_class kvs; + ccl::kvs::address_type main_addr; + if (rank == 0) { + kvs = ccl::create_main_kvs(); + main_addr = kvs->get_address(); + MPI_Bcast((void *)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); + } + else { + MPI_Bcast((void *)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); + kvs = ccl::create_kvs(main_addr); + } + + /* create communicator */ + auto dev = ccl::create_device(q.get_device()); + auto ctx = ccl::create_context(q.get_context()); + auto comm = ccl::create_communicator(size, rank, dev, ctx, kvs); + + /* create stream */ + auto stream = ccl::create_stream(q); + + /* create buffers */ + buffer send_buf(count * size); + buffer expected_buf(count); + buffer recv_buf(count); + + { + /* open buffers and initialize them on the host side */ + host_accessor send_buf_acc(send_buf, write_only); + host_accessor recv_buf_acc(recv_buf, write_only); + host_accessor expected_acc_buf(expected_buf, write_only); + + for (size_t i = 0; i < count * size; i++) { + send_buf_acc[i] = rank; + } + for (size_t i = 0; i < count; i++) { + recv_buf_acc[i] = -1; + } + + for (size_t i = 0; i < count; i++) { + expected_acc_buf[i] = size * (size + 1) / 2; + } + } + + /* open send_buf and modify it on the device side */ + q.submit([&](auto &h) { + accessor send_buf_acc(send_buf, h, write_only); + h.parallel_for(count * size, [=](auto id) { + send_buf_acc[id] += 1; + }); + }); + + if (!handle_exception(q)) + return -1; + + /* invoke reduce_scatter */ + ccl::reduce_scatter(send_buf, recv_buf, count, 
ccl::reduction::sum, comm, stream).wait(); + + /* open recv_buf and check its correctness on the device side */ + q.submit([&](auto &h) { + accessor recv_buf_acc(recv_buf, h, write_only); + accessor expected_buf_acc(expected_buf, h, read_only); + h.parallel_for(count, [=](auto id) { + if (recv_buf_acc[id] != expected_buf_acc[id]) { + recv_buf_acc[id] = -1; + } + }); + }); + + if (!handle_exception(q)) + return -1; + + /* print out the result of the test on the host side */ + { + host_accessor recv_buf_acc(recv_buf, read_only); + size_t i; + for (i = 0; i < count; i++) { + if (recv_buf_acc[i] == -1) { + cout << "FAILED\n"; + break; + } + } + if (i == count) { + cout << "PASSED\n"; + } + } + + return 0; +} diff --git a/examples/sycl/sycl_reduce_scatter_usm_test.cpp b/examples/sycl/sycl_reduce_scatter_usm_test.cpp new file mode 100644 index 000000000..fd953d314 --- /dev/null +++ b/examples/sycl/sycl_reduce_scatter_usm_test.cpp @@ -0,0 +1,130 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ +#include "sycl_base.hpp" + +using namespace std; +using namespace sycl; + +int main(int argc, char *argv[]) { + const size_t count = 10 * 1024 * 1024; + + int size = 0; + int rank = 0; + + ccl::init(); + + MPI_Init(NULL, NULL); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + atexit(mpi_finalize); + + queue q; + if (!create_sycl_queue(argc, argv, rank, q)) { + return -1; + } + + buf_allocator allocator(q); + + auto usm_alloc_type = usm::alloc::shared; + if (argc > 2) { + usm_alloc_type = usm_alloc_type_from_string(argv[2]); + } + + if (!check_sycl_usm(q, usm_alloc_type)) { + return -1; + } + + /* create kvs */ + ccl::shared_ptr_class kvs; + ccl::kvs::address_type main_addr; + if (rank == 0) { + kvs = ccl::create_main_kvs(); + main_addr = kvs->get_address(); + MPI_Bcast((void *)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); + } + else { + MPI_Bcast((void *)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); + kvs = ccl::create_kvs(main_addr); + } + + /* create communicator */ + auto dev = ccl::create_device(q.get_device()); + auto ctx = ccl::create_context(q.get_context()); + auto comm = ccl::create_communicator(size, rank, dev, ctx, kvs); + + /* create stream */ + auto stream = ccl::create_stream(q); + + /* create buffers */ + auto send_buf = allocator.allocate(count * size, usm_alloc_type); + auto recv_buf = allocator.allocate(count, usm_alloc_type); + + buffer expected_buf(count); + buffer check_buf(count); + + /* open buffers and modify them on the device side */ + auto e = q.submit([&](auto &h) { + accessor expected_buf_acc(expected_buf, h, write_only); + h.parallel_for(count, [=](auto id) { + recv_buf[id] = -1; + expected_buf_acc[id] = size * (size - 1) / 2; + for (int i = 0; i < size; i++) { + send_buf[i * count + id] = rank; + } + }); + }); + + /* do not wait completion of kernel and provide it as dependency for operation */ + vector deps; + deps.push_back(ccl::create_event(e)); + + /* 
invoke reduce_scatter */ + auto attr = ccl::create_operation_attr(); + ccl::reduce_scatter(send_buf, recv_buf, count, ccl::reduction::sum, comm, stream, attr, deps) + .wait(); + + /* open recv_buf and check its correctness on the device side */ + q.submit([&](auto &h) { + accessor expected_buf_acc(expected_buf, h, read_only); + accessor check_buf_acc(check_buf, h, write_only); + h.parallel_for(count, [=](auto id) { + if (recv_buf[id] != expected_buf_acc[id]) { + check_buf_acc[id] = -1; + } + }); + }); + + if (!handle_exception(q)) + return -1; + + /* print out the result of the test on the host side */ + { + host_accessor check_buf_acc(check_buf, read_only); + size_t i; + for (i = 0; i < count; i++) { + if (check_buf_acc[i] == -1) { + cout << "FAILED\n"; + break; + } + } + if (i == count) { + cout << "PASSED\n"; + } + } + + return 0; +} diff --git a/examples/sycl/sycl_reduce_usm_test.cpp b/examples/sycl/sycl_reduce_usm_test.cpp new file mode 100644 index 000000000..dc9dbd169 --- /dev/null +++ b/examples/sycl/sycl_reduce_usm_test.cpp @@ -0,0 +1,140 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ +#include "sycl_base.hpp" + +using namespace std; +using namespace sycl; + +int main(int argc, char* argv[]) { + const size_t count = 10 * 1024 * 1024; + int root_rank = 1; + + int size = 0; + int rank = 0; + + ccl::init(); + + MPI_Init(NULL, NULL); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + atexit(mpi_finalize); + + queue q; + if (!create_sycl_queue(argc, argv, rank, q)) { + return -1; + } + + buf_allocator allocator(q); + + auto usm_alloc_type = usm::alloc::shared; + if (argc > 2) { + usm_alloc_type = usm_alloc_type_from_string(argv[2]); + } + if (argc > 3) { + root_rank = atoi(argv[3]); + } + if (rank == root_rank) { + printf("root rank: %d\n", root_rank); + } + + if (!check_sycl_usm(q, usm_alloc_type)) { + return -1; + } + + /* create kvs */ + ccl::shared_ptr_class kvs; + ccl::kvs::address_type main_addr; + if (rank == 0) { + kvs = ccl::create_main_kvs(); + main_addr = kvs->get_address(); + MPI_Bcast((void*)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); + } + else { + MPI_Bcast((void*)main_addr.data(), main_addr.size(), MPI_BYTE, 0, MPI_COMM_WORLD); + kvs = ccl::create_kvs(main_addr); + } + + /* create communicator */ + auto dev = ccl::create_device(q.get_device()); + auto ctx = ccl::create_context(q.get_context()); + auto comm = ccl::create_communicator(size, rank, dev, ctx, kvs); + + /* create stream */ + auto stream = ccl::create_stream(q); + + /* create buffers */ + auto send_buf = allocator.allocate(count, usm_alloc_type); + auto recv_buf = allocator.allocate(count, usm_alloc_type); + + /* open buffers and modify them on the device side */ + auto e = q.submit([&](auto& h) { + h.parallel_for(count, [=](auto id) { + send_buf[id] = rank + id + 1; + recv_buf[id] = -1; + }); + }); + + int check_sum = 0; + for (int i = 1; i <= size; ++i) { + check_sum += i; + } + + /* do not wait completion of kernel and provide it as dependency for operation */ + vector deps; + deps.push_back(ccl::create_event(e)); 
+ + /* invoke reduce */ + auto attr = ccl::create_operation_attr(); + ccl::reduce(send_buf, recv_buf, count, ccl::reduction::sum, root_rank, comm, stream, attr, deps) + .wait(); + + /* open recv_buf and check its correctness on the device side */ + buffer check_buf(count); + + q.submit([&](auto& h) { + accessor check_buf_acc(check_buf, h, write_only); + h.parallel_for(count, [=](auto id) { + int expected = (rank == root_rank) ? (check_sum + size * id) : -1; + if (recv_buf[id] != expected) { + check_buf_acc[id] = -1; + } + else { + check_buf_acc[id] = 0; + } + }); + }); + + if (!handle_exception(q)) + return -1; + + /* print out the result of the test on the host side */ + { + host_accessor check_buf_acc(check_buf, read_only); + size_t i; + for (i = 0; i < count; i++) { + if (check_buf_acc[i] == -1) { + cout << "FAILED\n"; + break; + } + } + if (i == count) { + cout << "PASSED\n"; + } + } + + return 0; +} diff --git a/include/oneapi/ccl/coll_attr.hpp b/include/oneapi/ccl/coll_attr.hpp index a99454024..43d68e9b8 100644 --- a/include/oneapi/ccl/coll_attr.hpp +++ b/include/oneapi/ccl/coll_attr.hpp @@ -333,7 +333,6 @@ class barrier_attr : public ccl_api_base_copyable::type& version); - ; }; /** @@ -513,7 +512,6 @@ class reduce_scatter_attr : public ccl_api_base_copyable::type& version); - ; }; /** @@ -574,7 +572,6 @@ class sparse_allreduce_attr : public ccl_api_base_copyable::type& version); - ; }; /** diff --git a/include/oneapi/ccl/config.h.in b/include/oneapi/ccl/config.h.in index 1f2ad4e5f..6fe6015ea 100644 --- a/include/oneapi/ccl/config.h.in +++ b/include/oneapi/ccl/config.h.in @@ -30,19 +30,15 @@ #define ONECCL_SPEC_VERSION "1.0" -#define CCL_MAJOR_VERSION @CCL_MAJOR_VERSION@ -#define CCL_MINOR_VERSION @CCL_MINOR_VERSION@ -#define CCL_UPDATE_VERSION @CCL_UPDATE_VERSION@ -#cmakedefine CCL_PRODUCT_STATUS "@CCL_PRODUCT_STATUS@" -#cmakedefine CCL_PRODUCT_BUILD_DATE "@CCL_PRODUCT_BUILD_DATE@" -#cmakedefine CCL_PRODUCT_FULL "@CCL_PRODUCT_FULL@" +#define 
CCL_MAJOR_VERSION @CCL_MAJOR_VERSION@ +#define CCL_MINOR_VERSION @CCL_MINOR_VERSION@ +#define CCL_UPDATE_VERSION @CCL_UPDATE_VERSION@ +#cmakedefine CCL_PRODUCT_STATUS "@CCL_PRODUCT_STATUS@" +#cmakedefine CCL_PRODUCT_BUILD_DATE "@CCL_PRODUCT_BUILD_DATE@" +#cmakedefine CCL_PRODUCT_FULL "@CCL_PRODUCT_FULL@" /* Auto-generated configuration settings for SYCL support */ #cmakedefine CCL_ENABLE_SYCL -#ifdef CCL_ENABLE_SYCL -@CCL_ENABLE_SYCL_CHECK_CONTRACT@ -#endif - -/* Auto-generated configuration settings for multi GPU support*/ -#cmakedefine MULTI_GPU_SUPPORT +/* Auto-generated configuration settings for Level Zero support */ +#cmakedefine CCL_ENABLE_ZE diff --git a/include/oneapi/ccl/device_types.hpp b/include/oneapi/ccl/device_types.hpp index bcf807348..35367af55 100644 --- a/include/oneapi/ccl/device_types.hpp +++ b/include/oneapi/ccl/device_types.hpp @@ -27,7 +27,7 @@ namespace ccl { using process_id = size_t; using host_id = std::string; -#ifdef MULTI_GPU_SUPPORT +#ifdef CCL_ENABLE_ZE constexpr size_t CCL_GPU_DEVICES_AFFINITY_MASK_SIZE = 4; using device_mask_t = std::bitset; using process_aggregated_device_mask_t = std::map; diff --git a/include/oneapi/ccl/native_device_api/empty/export.hpp b/include/oneapi/ccl/native_device_api/empty/export.hpp index 44a239c0b..bf6190a3d 100644 --- a/include/oneapi/ccl/native_device_api/empty/export.hpp +++ b/include/oneapi/ccl/native_device_api/empty/export.hpp @@ -34,7 +34,7 @@ struct backend_info { return CL_BACKEND_TYPE; } static constexpr const char* name() { - return "BACKEND_UNAVAILABLE"; + return "EMPTY"; } }; @@ -59,6 +59,7 @@ struct generic_context_type { using impl_t = native::ccl_context; using ccl_native_t = std::shared_ptr; + generic_context_type() = default; template generic_context_type(T&& not_used) { (void)not_used; diff --git a/include/oneapi/ccl/native_device_api/export_api.hpp b/include/oneapi/ccl/native_device_api/export_api.hpp index f4b01b859..17347caae 100644 --- 
a/include/oneapi/ccl/native_device_api/export_api.hpp +++ b/include/oneapi/ccl/native_device_api/export_api.hpp @@ -17,13 +17,13 @@ #include "oneapi/ccl/config.h" #ifdef CCL_ENABLE_SYCL -#ifdef MULTI_GPU_SUPPORT +#ifdef CCL_ENABLE_ZE #include "sycl_l0/export.hpp" #else #include "sycl/export.hpp" #endif #else -#ifdef MULTI_GPU_SUPPORT +#ifdef CCL_ENABLE_ZE #include "l0/export.hpp" #else #include "empty/export.hpp" diff --git a/include/oneapi/ccl/native_device_api/l0/export.hpp b/include/oneapi/ccl/native_device_api/l0/export.hpp index c56d9d9c4..1e7921ae6 100644 --- a/include/oneapi/ccl/native_device_api/l0/export.hpp +++ b/include/oneapi/ccl/native_device_api/l0/export.hpp @@ -31,7 +31,7 @@ struct backend_info { return CL_BACKEND_TYPE; } static constexpr const char* name() { - return "LEVEL_ZERO_BACKEND"; + return "LEVEL_ZERO"; } }; diff --git a/include/oneapi/ccl/native_device_api/sycl/export.hpp b/include/oneapi/ccl/native_device_api/sycl/export.hpp index ba3ce9a34..f205e825b 100644 --- a/include/oneapi/ccl/native_device_api/sycl/export.hpp +++ b/include/oneapi/ccl/native_device_api/sycl/export.hpp @@ -29,7 +29,7 @@ struct backend_info { return CL_BACKEND_TYPE; } static constexpr const char* name() { - return "DPCPP_BACKEND"; + return "DPCPP"; } }; diff --git a/include/oneapi/ccl/native_device_api/sycl_l0/export.hpp b/include/oneapi/ccl/native_device_api/sycl_l0/export.hpp index ae36064a3..4bae8ed52 100644 --- a/include/oneapi/ccl/native_device_api/sycl_l0/export.hpp +++ b/include/oneapi/ccl/native_device_api/sycl_l0/export.hpp @@ -30,7 +30,7 @@ struct backend_info { return CL_BACKEND_TYPE; } static constexpr const char* name() { - return "DPCPP_LEVEL_ZERO_BACKEND"; + return "DPCPP_LEVEL_ZERO"; } }; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b262ba7bc..72e1eda33 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -19,112 +19,63 @@ set (EXTENSIONS_SRC) if (CCL_ENABLE_SYCL) list (APPEND EXTENSIONS_SRC - native_device_api/l0/utils.cpp - 
native_device_api/sycl/export.cpp - native_device_api/interop_utils.cpp + native_device_api/l0/utils.cpp + native_device_api/sycl/export.cpp + native_device_api/interop_utils.cpp ) endif(CCL_ENABLE_SYCL) -if (MULTI_GPU_SUPPORT) -list (APPEND EXTENSIONS_SRC - ccl_cpp_utils.cpp - - native_device_api/l0/base.cpp - native_device_api/l0/device.cpp - native_device_api/l0/context.cpp - native_device_api/l0/event_pool.cpp - native_device_api/l0/subdevice.cpp - native_device_api/l0/driver.cpp - native_device_api/l0/export.cpp - native_device_api/l0/platform.cpp - native_device_api/l0/utils.cpp - native_device_api/l0/primitives.cpp - native_device_api/interop_utils.cpp - - common/comm/l0/comm_context.cpp - common/comm/l0/comm_context_storage.cpp - common/comm/l0/context_comm_addr.cpp - - common/comm/l0/devices/ccl_gpu_base_comm.cpp - common/comm/l0/devices/ccl_gpu_comm.cpp - common/comm/l0/devices/ccl_virtual_gpu_comm.cpp - common/comm/l0/devices/ccl_ipc_gpu_comm.cpp - - common/comm/l0/devices/communication_structs/connection.cpp - common/comm/l0/devices/communication_structs/ipc_connection.cpp - common/comm/l0/devices/communication_structs/ipc_server.cpp - common/comm/l0/devices/communication_structs/ipc_client.cpp - - common/comm/l0/context/process_group_ctx.cpp - common/comm/l0/context/thread_group_ctx.cpp - common/comm/l0/context/device_group_ctx.cpp - common/comm/l0/context/device_storage.cpp - - common/comm/l0/topology/topology_serializer.cpp - common/comm/l0/topology/ring/device_group_ring_creator.cpp - common/comm/l0/topology/ring/thread_group_ring_creator.cpp - common/comm/l0/topology/ring/process_group_ring_creator.cpp - common/comm/l0/topology/topology_construction_utils.cpp - - common/comm/l0/context/scale/ipc/ipc_ctx_session.cpp - common/comm/l0/context/scale/ipc/ipc_ctx_utils.cpp - common/comm/l0/context/scale/ipc/ipc_session_key.cpp - - common/comm/l0/context/scale/base/base_session.cpp - common/comm/l0/context/scale/scale_out/scale_out_session.cpp - - 
common/comm/l0/gpu_comm_attr.cpp - common/comm/l0/modules/base_entry_module.cpp - common/comm/l0/modules/modules_source_data.cpp - common/comm/l0/modules/kernel_utils.cpp - - sched/gpu_sched.cpp - sched/gpu_concurrent_sched.cpp - sched/entry/gpu/ze_cache.cpp - sched/entry/gpu/ze_call.cpp - sched/entry/gpu/ze_primitives.cpp) -endif(MULTI_GPU_SUPPORT) - -if (CCL_ENABLE_SYCL AND MULTI_GPU_SUPPORT) +if (CCL_ENABLE_SYCL AND CCL_ENABLE_ZE) list (APPEND EXTENSIONS_SRC - sched/entry/gpu/ze_base_entry.cpp - sched/entry/gpu/ze_allreduce_entry.cpp - sched/entry/gpu/ze_copy_entry.cpp - sched/entry/gpu/ze_handle_exchange_entry.cpp - sched/entry/gpu/ze_event_signal_entry.cpp - sched/entry/gpu/ze_event_wait_entry.cpp - sched/entry/gpu/ze_reduce_entry.cpp - sched/entry/reduce_local_entry.cpp - sched/ze_handle_manager.cpp + + ccl_cpp_utils.cpp + + native_device_api/l0/base.cpp + native_device_api/l0/device.cpp + native_device_api/l0/context.cpp + native_device_api/l0/event_pool.cpp + native_device_api/l0/subdevice.cpp + native_device_api/l0/driver.cpp + native_device_api/l0/export.cpp + native_device_api/l0/platform.cpp + native_device_api/l0/utils.cpp + native_device_api/l0/primitives.cpp + native_device_api/interop_utils.cpp + + sched/entry/ze/allreduce/ze_a2a_allreduce_entry.cpp + sched/entry/ze/allreduce/ze_onesided_allreduce_entry.cpp + sched/entry/ze/allreduce/ze_ring_allreduce_entry.cpp + sched/entry/ze/ze_a2a_allgatherv_entry.cpp + sched/entry/ze/ze_a2a_gatherv_entry.cpp + sched/entry/ze/ze_a2a_reduce_scatter_entry.cpp + sched/entry/ze/ze_base_entry.cpp + sched/entry/ze/ze_barrier_entry.cpp + sched/entry/ze/ze_cache.cpp + sched/entry/ze/ze_call.cpp + sched/entry/ze/ze_copy_entry.cpp + sched/entry/ze/ze_handle_exchange_entry.cpp + sched/entry/ze/ze_event_signal_entry.cpp + sched/entry/ze/ze_event_wait_entry.cpp + sched/entry/ze/ze_onesided_reduce_entry.cpp + sched/entry/ze/ze_primitives.cpp + sched/entry/ze/ze_reduce_local_entry.cpp + + sched/ze/ze_event_manager.cpp + 
sched/ze/ze_handle_manager.cpp + sched/ze/ze_ipc_event_pool_manager.cpp + sched/ze/ze_list_manager.cpp + + coll/coll_util.cpp ) -endif(CCL_ENABLE_SYCL AND MULTI_GPU_SUPPORT) +endif(CCL_ENABLE_SYCL AND CCL_ENABLE_ZE) set(CCL_SRC - ccl_cpp_communicator.cpp - ccl_cpp_environment.cpp - ccl_api_functions.cpp - ccl_app_api_coll_attr.cpp - ccl_app_api_comm_attr.cpp - ccl_app_api_comm_split_attr.cpp - ccl_app_api_datatype_attr.cpp - ccl_app_api_kvs_attr.cpp - ccl_app_api_event.cpp - ccl_app_api_init_attr.cpp - ccl_cpp_kvs.cpp - ccl_cpp_device.cpp - ccl_cpp_stream.cpp - ccl_cpp_context.cpp - ccl_cpp_utils.cpp - ccl_empty_attr.cpp - ccl_empty_coll_attr.cpp - ccl_empty_comm_attr.cpp - ccl_empty_init_attr.cpp - ccl_empty_comm_split_attr.cpp - ccl_empty_kvs_attr.cpp - ccl_empty_stream.cpp - native_device_api/sycl_l0/export.cpp - native_device_api/empty/export.cpp - atl/atl_wrapper.cpp + + atl/atl_base_comm.cpp + atl/atl_def.cpp + atl/ofi/atl_ofi_comm.cpp + atl/mpi/atl_mpi_comm.cpp + atl/mpi/atl_mpi_global_data.cpp atl/mpi/atl_mpi.cpp atl/ofi/atl_ofi.cpp atl/ofi/atl_ofi_helper.cpp @@ -134,16 +85,15 @@ set(CCL_SRC atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.cpp atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs_keeper.cpp atl/util/pm/pmi_resizable_rt/pmi_resizable/pmi_listener.cpp - atl/util/pm/pmi_resizable_rt/pmi_resizable/rank_list.cpp atl/util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.cpp atl/util/pm/pmi_resizable_rt/pmi_resizable/resizable_pmi.cpp - atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.cpp atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.cpp atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs_server.cpp atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.cpp atl/util/pm/pmi_rt/pmi_simple.cpp atl/util/pm/pmi_rt/pmi/simple_pmi.c atl/util/pm/pmi_rt/pmi/simple_pmiutil.c + coll/coll_common_attributes.cpp coll/ccl_allgather_op_attr.cpp coll/ccl_allreduce_op_attr.cpp @@ -179,59 +129,93 @@ set(CCL_SRC 
coll/selection/selector_reduce.cpp coll/selection/selector_reduce_scatter.cpp coll/selection/selector_sparse_allreduce.cpp + + common/comm/atl_tag.cpp + common/comm/comm.cpp + common/comm/compiler_comm_interface_dispatcher.cpp + common/context/context.cpp + common/datatype/datatype.cpp + common/device/device.cpp + common/env/env.cpp + common/event/ccl_event.cpp + common/event/impls/host_event.cpp + common/event/impls/native_event.cpp + common/framework/framework.cpp + common/global/global.cpp + common/log/log.cpp + common/request/request.cpp + common/stream/stream.cpp + common/utils/memcpy.cpp + common/utils/spinlock.cpp + common/utils/version.cpp + common/utils/yield.cpp + comp/bf16/bf16.cpp comp/bf16/bf16_intrisics.cpp comp/comp.cpp comp/fp16/fp16.cpp comp/fp16/fp16_intrisics.cpp + + exec/exec.cpp + exec/thread/base_thread.cpp + exec/thread/listener.cpp + exec/thread/service_worker.cpp + exec/thread/worker.cpp + + fusion/fusion.cpp + hwloc/hwloc_wrapper.cpp - sched/sched.cpp - sched/buffer_cache.cpp - sched/extra_sched.cpp - sched/master_sched.cpp - sched/sched_base.cpp - sched/sched_timer.cpp + + native_device_api/sycl_l0/export.cpp + native_device_api/empty/export.cpp + + parallelizer/parallelizer.cpp + + sched/buffer/buffer_cache.cpp + sched/buffer/buffer_manager.cpp sched/cache/cache.cpp sched/cache/key.cpp - sched/queue/flow_control.cpp - sched/queue/strict_queue.cpp - sched/queue/queue.cpp sched/entry/coll/coll_entry.cpp sched/entry/coll/coll_entry_helper.cpp sched/entry/copy/copy_entry.cpp sched/entry/copy/copy_helper.cpp sched/entry/entry.cpp sched/entry/factory/chunked_entry_factory.cpp - exec/exec.cpp - exec/thread/base_thread.cpp - exec/thread/listener.cpp - exec/thread/service_worker.cpp - exec/thread/worker.cpp - fusion/fusion.cpp - parallelizer/parallelizer.cpp - unordered_coll/unordered_coll.cpp - - common/comm/atl_tag.cpp - common/comm/comm.cpp - common/comm/compiler_comm_interface_dispatcher.cpp - 
common/comm/host_communicator/host_communicator.cpp + sched/entry/recv_copy_entry.cpp + sched/entry/reduce_local_entry.cpp + sched/queue/flow_control.cpp + sched/queue/queue.cpp + sched/queue/strict_queue.cpp + sched/extra_sched.cpp + sched/master_sched.cpp + sched/sched.cpp + sched/sched_base.cpp + sched/sched_timer.cpp - common/context/context.cpp - common/datatype/datatype.cpp - common/device/device.cpp - common/event/ccl_event.cpp - common/stream/stream.cpp + unordered_coll/unordered_coll.cpp - common/env/env.cpp - common/global/global.cpp - common/log/log.cpp - common/event/impls/host_event.cpp - common/event/impls/native_event.cpp - common/framework/framework.cpp - common/request/request.cpp - common/utils/spinlock.cpp - common/utils/version.cpp - common/utils/yield.cpp + ccl_api_functions.cpp + ccl_app_api_coll_attr.cpp + ccl_app_api_comm_attr.cpp + ccl_app_api_comm_split_attr.cpp + ccl_app_api_datatype_attr.cpp + ccl_app_api_event.cpp + ccl_app_api_init_attr.cpp + ccl_app_api_kvs_attr.cpp + ccl_cpp_communicator.cpp + ccl_cpp_context.cpp + ccl_cpp_device.cpp + ccl_cpp_environment.cpp + ccl_cpp_kvs.cpp + ccl_cpp_stream.cpp + ccl_cpp_utils.cpp + ccl_empty_attr.cpp + ccl_empty_coll_attr.cpp + ccl_empty_comm_attr.cpp + ccl_empty_comm_split_attr.cpp + ccl_empty_init_attr.cpp + ccl_empty_kvs_attr.cpp + ccl_empty_stream.cpp ${EXTENSIONS_SRC}) @@ -262,6 +246,8 @@ if (${CMAKE_C_COMPILER_ID} STREQUAL "Intel" OR ${CMAKE_CXX_COMPILER_ID} STREQUAL set(SRC_C_FLAGS "${SRC_C_FLAGS} -prof-gen=srcpos -prof-src-root-cwd") set(SRC_CXX_FLAGS "${SRC_CXX_FLAGS} -prof-gen=srcpos -prof-src-root-cwd") endif() + #To suppress for 'offsetof applied to non-POD (Plain Old Data) types is nonstandar' + set(SRC_CXX_FLAGS "${SRC_CXX_FLAGS} -diag-disable=1875") endif() list(APPEND SRC_INCLUDE_DIRS @@ -308,11 +294,6 @@ endif() add_library(ccl SHARED $) target_include_directories(ccl PUBLIC ${SRC_INCLUDE_DIRS}) -# link with release_mt libmpi.so for oneAPI Base toolkit -# libccl.so -> 
cpu_icc/cpu_gpu_dpcpp -> lib -> latest -> ccl -> mpi -> ... -set(ONEAPI_IMPI_RPATH "'$ORIGIN'/../../../../mpi/latest/lib/release_mt/") -set_target_properties(ccl PROPERTIES LINK_FLAGS "-Wl,-rpath,${ONEAPI_IMPI_RPATH}") - target_link_libraries(ccl PUBLIC ${SRC_LINK_LIBS}) if (NOT LIB_SO_VERSION AND NOT LIB_MAJOR_VERSION) @@ -330,7 +311,6 @@ message(STATUS "SRC LINK_LIBS: ${SRC_LINK_LIBS}") install(TARGETS ccl LIBRARY DESTINATION ${CCL_INSTALL_LIB}) install(FILES - "${PROJECT_SOURCE_DIR}/cmake/FindComputeCpp.cmake" "${PROJECT_SOURCE_DIR}/cmake/FindIntelSYCL.cmake" "${PROJECT_SOURCE_DIR}/cmake/FindIntelSYCL_level_zero.cmake" "${PROJECT_SOURCE_DIR}/cmake/Findlevel_zero.cmake" diff --git a/src/atl/atl.h b/src/atl/atl.h deleted file mode 100644 index c515fa483..000000000 --- a/src/atl/atl.h +++ /dev/null @@ -1,154 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-*/ -#pragma once - -#include -#include -#include - -#include "atl_def.h" -#include "common/log/log.hpp" -#include "util/pm/pm_rt.h" - -#ifdef __cplusplus -class iatl { -public: - virtual ~iatl() = default; - - virtual atl_status_t atl_init(int* argc, - char*** argv, - atl_attr_t* attr, - const char* main_addr, - std::unique_ptr& pmi) = 0; - - virtual atl_status_t atl_finalize() = 0; - - virtual atl_status_t atl_update(std::unique_ptr& pmi) = 0; - - virtual atl_ep_t** atl_get_eps() = 0; - - virtual atl_proc_coord_t* atl_get_proc_coord() = 0; - - virtual atl_status_t atl_mr_reg(const void* buf, size_t len, atl_mr_t** mr) = 0; - - virtual atl_status_t atl_mr_dereg(atl_mr_t* mr) = 0; - - virtual atl_status_t atl_ep_send(atl_ep_t* ep, - const void* buf, - size_t len, - int dst_proc_idx, - uint64_t tag, - atl_req_t* req) = 0; - - virtual atl_status_t atl_ep_recv(atl_ep_t* ep, - void* buf, - size_t len, - int src_proc_idx, - uint64_t tag, - atl_req_t* req) = 0; - - virtual atl_status_t atl_ep_probe(atl_ep_t* ep, - int src_proc_idx, - uint64_t tag, - int* found, - size_t* recv_len) = 0; - - virtual atl_status_t atl_ep_allgatherv(atl_ep_t* ep, - const void* send_buf, - size_t send_len, - void* recv_buf, - const int* recv_lens, - const int* offsets, - atl_req_t* req) = 0; - - virtual atl_status_t atl_ep_allreduce(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t len, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) = 0; - - virtual atl_status_t atl_ep_alltoall(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - int len, - atl_req_t* req) = 0; - - virtual atl_status_t atl_ep_alltoallv(atl_ep_t* ep, - const void* send_buf, - const int* send_lens, - const int* send_offsets, - void* recv_buf, - const int* recv_lens, - const int* recv_offsets, - atl_req_t* req) = 0; - - virtual atl_status_t atl_ep_barrier(atl_ep_t* ep, atl_req_t* req) = 0; - - virtual atl_status_t atl_ep_bcast(atl_ep_t* ep, - void* buf, - size_t len, - int root, - atl_req_t* 
req) = 0; - - virtual atl_status_t atl_ep_reduce(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t len, - int root, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) = 0; - - virtual atl_status_t atl_ep_reduce_scatter(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t recv_len, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) = 0; - - virtual atl_status_t atl_ep_read(atl_ep_t* ep, - void* buf, - size_t len, - atl_mr_t* mr, - uint64_t addr, - uintptr_t remote_key, - int dst_proc_idx, - atl_req_t* req) = 0; - - virtual atl_status_t atl_ep_write(atl_ep_t* ep, - const void* buf, - size_t len, - atl_mr_t* mr, - uint64_t addr, - uintptr_t remote_key, - int dst_proc_idx, - atl_req_t* req) = 0; - - virtual atl_status_t atl_ep_wait(atl_ep_t* ep, atl_req_t* req) = 0; - - virtual atl_status_t atl_ep_wait_all(atl_ep_t* ep, atl_req_t* req, size_t count) = 0; - - virtual atl_status_t atl_ep_cancel(atl_ep_t* ep, atl_req_t* req) = 0; - - virtual atl_status_t atl_ep_poll(atl_ep_t* ep) = 0; - - virtual atl_status_t atl_ep_check(atl_ep_t* ep, int* is_completed, atl_req_t* req) = 0; - virtual bool is_inited() = 0; -}; -#endif diff --git a/src/atl/atl_base_comm.cpp b/src/atl/atl_base_comm.cpp new file mode 100644 index 000000000..3388da829 --- /dev/null +++ b/src/atl/atl_base_comm.cpp @@ -0,0 +1,157 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ +#ifdef CCL_ENABLE_MPI +#include "atl/mpi/atl_mpi.hpp" +#include "atl/mpi/atl_mpi_comm.hpp" +#endif // CCL_ENABLE_MPI + +#include "atl/atl_base_comm.hpp" +#include "atl/ofi/atl_ofi_comm.hpp" +#include "atl/ofi/atl_ofi.hpp" +#include "atl/util/pm/pm_rt.h" +#include "exec/exec.hpp" + +atl_attr_t atl_base_comm::attr = { + /* in */ + { + 0, /* enable_shm */ + 0, /* enable_rma */ + 0, /* enable_hmem */ + 0, /* enable_sync_coll */ + 0, /* enable_extra_ep */ + 1, /* ep_count */ + ATL_MNIC_NONE, /* mnic_type */ + "", /* mnic_name */ + 1, /* mnic_count */ + ATL_MNIC_OFFSET_NONE /* mnic_offset */ + }, + + /* out */ + { + 0, /* enable_shm */ + 0, /* enable_rma */ + 0, /* enable_hmem */ + ATL_MNIC_NONE, /* mnic_type */ + 0, /* mnic_count */ + 0, /* tag_bits */ + 0, /* max_tag */ + 0, /* max_order_waw_size */ + } +}; + +ccl_executor* atl_base_comm::executor = nullptr; + +void atl_base_comm::init_tag() { + tag = std::unique_ptr(new ccl_atl_tag(attr.out.tag_bits, attr.out.max_tag)); + if (rank == 0) { + LOG_DEBUG("atl tag: ", tag->to_string()); + } +} + +void atl_base_comm::print_atl_attrs() { + std::stringstream ss; + + ss << "atl attrs:\n{\n" + << " in: { " + << "shm: " << attr.in.enable_shm << ", hmem: " << attr.in.enable_hmem + << ", sync_coll: " << attr.in.enable_sync_coll << ", extra_ep: " << attr.in.enable_extra_ep + << ", ep_count: " << attr.in.ep_count << ", mnic_type: " << to_string(attr.in.mnic_type) + << ", mnic_count: " << attr.in.mnic_count + << ", mnic_offset: " << to_string(attr.in.mnic_offset) << " }\n" + << " out: { " + << "shm: " << attr.out.enable_shm << ", hmem: " << attr.out.enable_hmem + << ", mnic_type: " << to_string(attr.out.mnic_type) + << ", mnic_count: " << attr.out.mnic_count << ", tag_bits: " << attr.out.tag_bits + << ", max_tag: " << attr.out.max_tag << " }\n}"; + + LOG_INFO(ss.str()); +} + +void atl_base_comm::executor_update() { + if (!executor->are_workers_started()) { + if (rank < coord.local_count) + LOG_INFO( + "start workers for local 
process [", coord.local_idx, ":", coord.local_count, "]"); + executor->start_workers(coord.local_idx, coord.local_count); + } +} + +std::shared_ptr atl_comm_manager::create_comm() { + std::shared_ptr atl_comm; + + auto transport_type = ccl::global_data::env().atl_transport; + + switch (transport_type) { + case ccl_atl_ofi: atl_comm = std::shared_ptr(new atl_ofi_comm()); break; +#ifdef CCL_ENABLE_MPI + case ccl_atl_mpi: atl_comm = std::shared_ptr(new atl_mpi_comm()); break; +#endif // CCL_ENABLE_MPI + default: LOG_ERROR("Unsupported yet"); break; + } + return atl_comm; +} + +std::shared_ptr atl_comm_manager::create_comm(std::shared_ptr k) { + std::shared_ptr atl_comm; + + auto transport_type = ccl::global_data::env().atl_transport; + + switch (transport_type) { + case ccl_atl_ofi: atl_comm = std::shared_ptr(new atl_ofi_comm(k)); break; +#ifdef CCL_ENABLE_MPI + case ccl_atl_mpi: atl_comm = std::shared_ptr(new atl_mpi_comm(k)); break; +#endif // CCL_ENABLE_MPI + default: LOG_ERROR("Unsupported yet"); break; + } + return atl_comm; +} + +std::shared_ptr atl_comm_manager::create_comm(int total_rank_count, + const std::vector& ranks, + std::shared_ptr k) { + std::shared_ptr atl_comm; + + auto transport_type = ccl::global_data::env().atl_transport; + + switch (transport_type) { + case ccl_atl_ofi: + atl_comm = std::shared_ptr(new atl_ofi_comm(total_rank_count, ranks, k)); + break; +#ifdef CCL_ENABLE_MPI + case ccl_atl_mpi: + atl_comm = std::shared_ptr(new atl_mpi_comm(total_rank_count, ranks, k)); + break; +#endif // CCL_ENABLE_MPI + default: LOG_ERROR("Unsupported yet"); break; + } + return atl_comm; +} + +void atl_comm_manager::set_internal_env(const atl_attr_t& attr) { + auto transport_type = ccl::global_data::env().atl_transport; + atl_base_comm::attr = attr; + + if (transport_type == ccl_atl_ofi) + atl_ofi::atl_set_env(attr); +#ifdef CCL_ENABLE_MPI + else if (transport_type == ccl_atl_mpi) + atl_mpi::set_env(attr); +#endif // CCL_ENABLE_MPI +} + +void 
atl_comm_manager::set_exec(ccl_executor* exec) { + atl_base_comm::executor = exec; +} diff --git a/src/atl/atl_base_comm.hpp b/src/atl/atl_base_comm.hpp new file mode 100644 index 000000000..189dbcf65 --- /dev/null +++ b/src/atl/atl_base_comm.hpp @@ -0,0 +1,209 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#include +#include +#include +#include + +#include "atl/atl_def.h" +#include "common/comm/atl_tag.hpp" +#include "util/pm/pmi_resizable_rt/pmi_resizable/kvs/ikvs_wrapper.h" +#include "util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.h" +#include "util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.h" + +class ccl_executor; + +class atl_base_comm { +protected: + atl_base_comm() = default; + +public: + virtual ~atl_base_comm() = default; + + virtual atl_status_t main_addr_reserve(char* main_addr) = 0; + + virtual atl_status_t finalize() = 0; + + virtual atl_status_t update() = 0; + + virtual atl_status_t wait_notification() = 0; + + virtual atl_status_t set_resize_function(atl_resize_fn_t fn) = 0; + + virtual atl_status_t mr_reg(const void* buf, size_t len, atl_mr_t** mr) = 0; + + virtual atl_status_t mr_dereg(atl_mr_t* mr) = 0; + + virtual atl_status_t send(size_t ep_idx, + const void* buf, + size_t len, + int dst_proc_idx, + uint64_t tag, + atl_req_t* req) = 0; + + virtual atl_status_t recv(size_t ep_idx, + void* buf, + size_t len, + int src_proc_idx, + uint64_t tag, + atl_req_t* req) = 0; + + virtual 
atl_status_t probe(size_t ep_idx, + int src_proc_idx, + uint64_t tag, + int* found, + size_t* recv_len) = 0; + + virtual atl_status_t allgatherv(size_t ep_idx, + const void* send_buf, + size_t send_len, + void* recv_buf, + const int* recv_lens, + const int* offsets, + atl_req_t* req) = 0; + + virtual atl_status_t allreduce(size_t ep_idx, + const void* send_buf, + void* recv_buf, + size_t len, + atl_datatype_t dtype, + atl_reduction_t op, + atl_req_t* req) = 0; + + virtual atl_status_t alltoall(size_t ep_idx, + const void* send_buf, + void* recv_buf, + int len, + atl_req_t* req) = 0; + + virtual atl_status_t alltoallv(size_t ep_idx, + const void* send_buf, + const int* send_lens, + const int* send_offsets, + void* recv_buf, + const int* recv_lens, + const int* recv_offsets, + atl_req_t* req) = 0; + + virtual atl_status_t barrier(size_t ep_idx, atl_req_t* req) = 0; + + virtual atl_status_t bcast(size_t ep_idx, void* buf, size_t len, int root, atl_req_t* req) = 0; + + virtual atl_status_t reduce(size_t ep_idx, + const void* send_buf, + void* recv_buf, + size_t len, + int root, + atl_datatype_t dtype, + atl_reduction_t op, + atl_req_t* req) = 0; + + virtual atl_status_t reduce_scatter(size_t ep_idx, + const void* send_buf, + void* recv_buf, + size_t recv_len, + atl_datatype_t dtype, + atl_reduction_t op, + atl_req_t* req) = 0; + + virtual atl_status_t read(size_t ep_idx, + void* buf, + size_t len, + atl_mr_t* mr, + uint64_t addr, + uintptr_t remote_key, + int dst_proc_idx, + atl_req_t* req) = 0; + + virtual atl_status_t write(size_t ep_idx, + const void* buf, + size_t len, + atl_mr_t* mr, + uint64_t addr, + uintptr_t remote_key, + int dst_proc_idx, + atl_req_t* req) = 0; + + virtual atl_status_t wait(size_t ep_idx, atl_req_t* req) = 0; + + virtual atl_status_t wait_all(size_t ep_idx, atl_req_t* req, size_t count) = 0; + + virtual atl_status_t cancel(size_t ep_idx, atl_req_t* req) = 0; + + virtual atl_status_t poll(size_t ep_idx) = 0; + + virtual atl_status_t 
check(size_t ep_idx, atl_req_t* req) = 0; + + virtual size_t get_threads_per_process() = 0; + + virtual size_t get_ranks_per_process() = 0; + + virtual int get_rank() = 0; + + virtual int get_size() = 0; + + virtual int get_r2r_color() = 0; + + virtual int get_host_color() = 0; + + virtual std::shared_ptr comm_split(int color) = 0; + + virtual std::vector get_rank2rank_map() = 0; + + /* + * TODO: Temporary change. + * Need to define correct to unique id + */ + virtual size_t get_id() = 0; + std::unique_ptr tag; + static atl_attr_t attr; + +protected: + void init_tag(); + void print_atl_attrs(); + void executor_update(); + + friend class atl_comm_manager; + static ccl_executor* executor; + + int rank; + int size; + + size_t threads_per_process; + size_t ranks_per_process; + + std::vector rank2rank_map; + atl_proc_coord_t coord; + int parent_rank; + int parent_size; + + std::shared_ptr pmi; +}; + +class atl_comm_manager { +public: + static std::shared_ptr create_comm(); + + static std::shared_ptr create_comm(std::shared_ptr k); + + static std::shared_ptr create_comm(int total_rank_count, + const std::vector& ranks, + std::shared_ptr k); + static void set_internal_env(const atl_attr_t& attr); + static void set_exec(ccl_executor* exec); +}; diff --git a/src/atl/atl_def.cpp b/src/atl/atl_def.cpp new file mode 100644 index 000000000..16335b70d --- /dev/null +++ b/src/atl/atl_def.cpp @@ -0,0 +1,45 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. +*/ +#include "atl/atl_def.h" + +std::map mnic_type_names = { std::make_pair(ATL_MNIC_NONE, "none"), + std::make_pair(ATL_MNIC_LOCAL, "local"), + std::make_pair(ATL_MNIC_GLOBAL, "global") }; + +std::map mnic_offset_names = { + std::make_pair(ATL_MNIC_OFFSET_NONE, "none"), + std::make_pair(ATL_MNIC_OFFSET_LOCAL_PROC_IDX, "local_proc_idx") +}; + +std::string to_string(atl_mnic_t type) { + auto it = mnic_type_names.find(type); + if (it != mnic_type_names.end()) { + return it->second; + } + else { + return "unknown"; + } +} + +std::string to_string(atl_mnic_offset_t offset) { + auto it = mnic_offset_names.find(offset); + if (it != mnic_offset_names.end()) { + return it->second; + } + else { + return "unknown"; + } +} diff --git a/src/atl/atl_def.h b/src/atl/atl_def.h index 3e02a0aa8..557c164f2 100644 --- a/src/atl/atl_def.h +++ b/src/atl/atl_def.h @@ -15,10 +15,16 @@ */ #pragma once +#include +#include +#include +#include #include #include #include +#include "common/log/log.hpp" + #ifndef container_of #define container_of(ptr, type, field) ((type*)((char*)ptr - offsetof(type, field))) #endif @@ -45,11 +51,36 @@ * This is invoked by the ATL framework when the transport library is loaded. */ +#define ATL_CHECK_STATUS(expr, str) \ + do { \ + if (expr != ATL_STATUS_SUCCESS) { \ + LOG_ERROR(str); \ + return ATL_STATUS_FAILURE; \ + } \ + } while (0) + +#define KVS_2_ATL_CHECK_STATUS(expr, str) \ + do { \ + if (expr != KVS_STATUS_SUCCESS) { \ + LOG_ERROR(str); \ + return ATL_STATUS_FAILURE; \ + } \ + } while (0) + +#define ATL_SET_STR(dst, size, ...) 
\ + do { \ + if (snprintf(dst, size, __VA_ARGS__) > size) { \ + printf("line too long (must be shorter %d)\n", size); \ + printf(__VA_ARGS__); \ + return ATL_STATUS_FAILURE; \ + } \ + } while (0) + #define ATL_CALL(func, err_action) \ do { \ atl_status_t status = func; \ if (status != FI_SUCCESS) { \ - CCL_THROW(#func "\n fails with status: ", status); \ + LOG_ERROR(#func "\n fails with status: ", status); \ err_action; \ } \ } while (0) @@ -107,6 +138,13 @@ typedef enum { } atl_reduction_t; typedef enum { ATL_MNIC_NONE, ATL_MNIC_LOCAL, ATL_MNIC_GLOBAL } atl_mnic_t; +typedef enum { ATL_MNIC_OFFSET_NONE, ATL_MNIC_OFFSET_LOCAL_PROC_IDX } atl_mnic_offset_t; + +extern std::map mnic_type_names; +extern std::map mnic_offset_names; + +std::string to_string(atl_mnic_t type); +std::string to_string(atl_mnic_offset_t offset); typedef struct { struct { @@ -119,6 +157,7 @@ typedef struct { atl_mnic_t mnic_type; std::string mnic_name; size_t mnic_count; + atl_mnic_offset_t mnic_offset; } in; struct { int enable_shm; @@ -147,17 +186,18 @@ typedef struct { size_t hostname_hash; } atl_proc_coord_t; -typedef struct { - uint64_t tag; - size_t remote_proc_idx; +typedef struct atl_req { + int is_completed; void* internal[ATL_REQ_SIZE]; -} atl_req_t __attribute__((aligned(ATL_CACHELINE_LEN))); + atl_req() : is_completed(0) { + memset(internal, 0, ATL_REQ_SIZE * sizeof(void*)); + } +} atl_req_t; struct atl_ctx { atl_proc_coord_t coord; size_t ep_count; - atl_ep_t** eps; }; /* diff --git a/src/atl/atl_wrapper.cpp b/src/atl/atl_wrapper.cpp deleted file mode 100644 index 00eba6f9d..000000000 --- a/src/atl/atl_wrapper.cpp +++ /dev/null @@ -1,239 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#include "atl/util/pm/pmi_resizable_rt/pmi_resizable_simple.h" -#include "atl/util/pm/pmi_rt/pmi_simple.h" -#include "atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.h" -#include "atl/util/pm/pmi_resizable_rt/pmi_resizable.h" -#include "atl/ofi/atl_ofi.hpp" -#ifdef CCL_ENABLE_MPI -#include "atl/mpi/atl_mpi.hpp" -#endif // CCL_ENABLE_MPI -#include "atl_wrapper.h" -#include "common/global/global.hpp" -#include "exec/exec.hpp" -#include "util/pm/pmi_resizable_rt/pmi_resizable_simple_internal.h" - -static std::list> transports{}; -static ccl_executor* executor; - -atl_attr_t atl_wrapper::attr = { - /* in */ - { - 0, /* enable_shm */ - 0, /* enable_rma */ - 0, /* enable_hmem */ - 0, /* enable_sync_coll */ - 0, /* enable_extra_ep */ - 1, /* ep_count */ - ATL_MNIC_NONE, /* mnic_type */ - "", /* mnic_name */ - 1 /* mnic_count */ - }, - - /* out */ - {} -}; - -void atl_wrapper::set_internal_env(const atl_attr_t& attr) { - auto transport_type = ccl::global_data::env().atl_transport; - - if (transport_type == ccl_atl_ofi) - atl_ofi::atl_set_env(attr); -#ifdef CCL_ENABLE_MPI - else if (transport_type == ccl_atl_mpi) - atl_mpi::atl_set_env(attr); -#endif // CCL_ENABLE_MPI -} - -void atl_wrapper::set_exec(ccl_executor* exec) { - executor = exec; -} - -atl_wrapper::atl_wrapper() { - auto transport_type = ccl::global_data::env().atl_transport; - - char* pm_type_str; - switch (transport_type) { - case ccl_atl_ofi: - pm_type_str = getenv(PM_TYPE); - if (pm_type_str) { - if (strstr(pm_type_str, PM_RT_VAL_SIMPLE)) { - pmi = std::unique_ptr(new pmi_simple()); - } - else 
if (strstr(pm_type_str, PM_RT_VAL_RESIZABLE)) { - std::shared_ptr k(new internal_kvs()); - pmi = std::unique_ptr(new pmi_resizable(k)); - } - else { - LOG_ERROR("Unknown %s: %s\n", PM_TYPE, pm_type_str); - } - } - else { - pmi = std::unique_ptr(new pmi_simple()); - } - transport = std::shared_ptr(new atl_ofi()); - break; -#ifdef CCL_ENABLE_MPI - case ccl_atl_mpi: transport = std::shared_ptr(new atl_mpi()); break; -#endif // CCL_ENABLE_MPI - default: LOG_ERROR("Unsupported yet"); break; - } - - init_transport(); -} - -atl_wrapper::atl_wrapper(std::shared_ptr k) { - auto transport_type = ccl::global_data::env().atl_transport; - - char* pm_type_str; - switch (transport_type) { - case ccl_atl_ofi: - pm_type_str = getenv(PM_TYPE); - if (pm_type_str) { - if (strstr(pm_type_str, PM_RT_VAL_SIMPLE)) { - pmi = std::unique_ptr(new pmi_simple()); - } - else if (strstr(pm_type_str, PM_RT_VAL_RESIZABLE)) { - pmi = std::unique_ptr(new pmi_resizable(k)); - } - else { - LOG_ERROR("Unknown %s: %s\n", PM_TYPE, pm_type_str); - } - } - else { - pmi = std::unique_ptr(new pmi_simple()); - } - transport = std::shared_ptr(new atl_ofi()); - break; -#ifdef CCL_ENABLE_MPI - case ccl_atl_mpi: transport = std::shared_ptr(new atl_mpi()); break; -#endif // CCL_ENABLE_MPI - default: LOG_ERROR("Unsupported yet"); break; - } - - init_transport(); -} - -atl_wrapper::atl_wrapper(int total_rank_count, - const std::vector& ranks, - std::shared_ptr k) { - auto transport_type = ccl::global_data::env().atl_transport; - - switch (transport_type) { - case ccl_atl_ofi: { - size_t transorts_count = transports.size(); - std::shared_ptr kvs; - if ((kvs = std::dynamic_pointer_cast(k)) != nullptr) { - pmi = std::unique_ptr( - new pmi_resizable_simple_internal(total_rank_count, ranks, kvs)); - } - else { - pmi = std::unique_ptr(new pmi_resizable_simple(total_rank_count, ranks, k)); - } - - if (pmi->get_local_thread_idx() == 0) { - transports.push_back(std::shared_ptr(new atl_ofi())); - } - //TODO: Rework it on 
barrier - while (transorts_count == transports.size()) { - ccl_yield(ccl::global_data::env().yield_type); - } - static std::mutex memory_mutex; - { - std::lock_guard lock(memory_mutex); - transport = transports.back(); - } - } break; -#ifdef CCL_ENABLE_MPI - case ccl_atl_mpi: transport = std::shared_ptr(new atl_mpi()); break; -#endif // CCL_ENABLE_MPI - default: LOG_ERROR("Unsupported yet"); break; - } - - init_transport(); -} -void atl_wrapper::init_transport() { - LOG_DEBUG("init ATL, requested ep_count ", attr.in.ep_count); - static std::mutex memory_mutex; - { - std::lock_guard lock(memory_mutex); - if (!transport->is_inited()) { - CCL_THROW_IF_NOT( - transport->atl_init(nullptr, nullptr, &attr, nullptr, pmi) == ATL_STATUS_SUCCESS, - "failed to initialize ATL"); - } - } - eps = transport->atl_get_eps(); - tag = std::unique_ptr(new ccl_atl_tag(attr.out.tag_bits, attr.out.max_tag)); - - if (pmi) { - threads_per_process = pmi->get_threads_per_process(); - ranks_per_process = pmi->get_ranks_per_process(); - rank = pmi->get_rank(); - size = pmi->get_size(); - } -#ifdef CCL_ENABLE_MPI - else { - threads_per_process = 1; - ranks_per_process = 1; - rank = static_cast(transport.get())->get_rank(); - size = static_cast(transport.get())->get_size(); - } -#endif // CCL_ENABLE_MPI - - if (rank == 0) { - tag->print(); - LOG_INFO("atl-in-attrs:"); - LOG_INFO(" enable_shm: ", attr.in.enable_shm); - LOG_INFO(" enable_rma: ", attr.in.enable_rma); - LOG_INFO(" enable_hmem: ", attr.in.enable_hmem); - LOG_INFO(" enable_sync_coll: ", attr.in.enable_sync_coll); - LOG_INFO(" enable_extra_ep: ", attr.in.enable_extra_ep); - LOG_INFO(" ep_count: ", attr.in.ep_count); - LOG_INFO(" mnic_type: ", attr.in.mnic_type); - LOG_INFO(" mnic_count: ", attr.in.mnic_count); - - LOG_INFO("atl-out-attrs:"); - LOG_INFO(" enable_shm: ", attr.out.enable_shm); - LOG_INFO(" enable_rma: ", attr.out.enable_rma); - LOG_INFO(" enable_hmem: ", attr.out.enable_hmem); - LOG_INFO(" mnic_type: ", 
attr.out.mnic_type); - LOG_INFO(" mnic_count: ", attr.out.mnic_count); - LOG_INFO(" tag_bits: ", attr.out.tag_bits); - LOG_INFO(" max_tag: ", attr.out.max_tag); - LOG_INFO(" max_order_waw_size: ", attr.out.max_order_waw_size); - } - - if ((!pmi) || (pmi && pmi->get_local_thread_idx() == 0)) { - if (!executor->are_workers_started()) { - atl_proc_coord_t* coord = atl_get_proc_coord(); - if (rank < coord->local_count) - LOG_INFO("start workers for local process [", - coord->local_idx, - ":", - coord->local_count, - "]"); - executor->start_workers(coord->local_idx, coord->local_count); - } - } -} - -atl_wrapper::~atl_wrapper() { - static std::mutex memory_mutex; - std::lock_guard lock(memory_mutex); - transports.remove(transport); - tag.reset(); -} diff --git a/src/atl/atl_wrapper.h b/src/atl/atl_wrapper.h deleted file mode 100644 index 4c5ec38e1..000000000 --- a/src/atl/atl_wrapper.h +++ /dev/null @@ -1,279 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-*/ -#pragma once - -#include -#include -#include -#include - -#include "atl.h" -#include "common/comm/atl_tag.hpp" -#include "util/pm/pmi_resizable_rt/pmi_resizable/kvs/ikvs_wrapper.h" -#include "util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.h" -#include "util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.h" - -class ccl_executor; - -class atl_wrapper { -public: - static void set_internal_env(const atl_attr_t& attr); - static void set_exec(ccl_executor* exec); - - ~atl_wrapper(); - atl_wrapper(); - atl_wrapper(std::shared_ptr k); - atl_wrapper(int total_rank_count, - const std::vector& ranks, - std::shared_ptr k); - - atl_status_t atl_main_addr_reserve(char* main_addr) { - if (!pmi) - return ATL_STATUS_UNSUPPORTED; - - return pmi->pmrt_main_addr_reserve(main_addr); - ; - } - - atl_status_t atl_finalize() { - if (pmi) - pmi->pmrt_finalize(); - - return transport->atl_finalize(); - } - - atl_status_t atl_update() { - return transport->atl_update(pmi); - } - - atl_status_t atl_wait_notification() { - if (!pmi) - return ATL_STATUS_UNSUPPORTED; - - return pmi->pmrt_wait_notification(); - } - - atl_status_t atl_set_resize_function(atl_resize_fn_t fn) { - if (!pmi) - return ATL_STATUS_UNSUPPORTED; - - return pmi->pmrt_set_resize_function(fn); - } - - atl_proc_coord_t* atl_get_proc_coord() { - return transport->atl_get_proc_coord(); - } - - atl_status_t atl_mr_reg(const void* buf, size_t len, atl_mr_t** mr) { - return transport->atl_mr_reg(buf, len, mr); - } - - atl_status_t atl_mr_dereg(atl_mr_t* mr) { - return transport->atl_mr_dereg(mr); - } - - atl_status_t atl_ep_send(size_t ep_idx, - const void* buf, - size_t len, - int dst_proc_idx, - uint64_t tag, - atl_req_t* req) { - return transport->atl_ep_send(eps[ep_idx], buf, len, dst_proc_idx, tag, req); - } - - atl_status_t atl_ep_recv(size_t ep_idx, - void* buf, - size_t len, - int src_proc_idx, - uint64_t tag, - atl_req_t* req) { - return transport->atl_ep_recv(eps[ep_idx], buf, len, src_proc_idx, tag, req); 
- } - - atl_status_t atl_ep_probe(size_t ep_idx, - int src_proc_idx, - uint64_t tag, - int* found, - size_t* recv_len) { - return transport->atl_ep_probe(eps[ep_idx], src_proc_idx, tag, found, recv_len); - } - - atl_status_t atl_ep_allgatherv(size_t ep_idx, - const void* send_buf, - size_t send_len, - void* recv_buf, - const int* recv_lens, - const int* offsets, - atl_req_t* req) { - return transport->atl_ep_allgatherv( - eps[ep_idx], send_buf, send_len, recv_buf, recv_lens, offsets, req); - } - - atl_status_t atl_ep_allreduce(size_t ep_idx, - const void* send_buf, - void* recv_buf, - size_t len, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) { - return transport->atl_ep_allreduce(eps[ep_idx], send_buf, recv_buf, len, dtype, op, req); - } - - atl_status_t atl_ep_alltoall(size_t ep_idx, - const void* send_buf, - void* recv_buf, - int len, - atl_req_t* req) { - return transport->atl_ep_alltoall(eps[ep_idx], send_buf, recv_buf, len, req); - } - - atl_status_t atl_ep_alltoallv(size_t ep_idx, - const void* send_buf, - const int* send_lens, - const int* send_offsets, - void* recv_buf, - const int* recv_lens, - const int* recv_offsets, - atl_req_t* req) { - return transport->atl_ep_alltoallv( - eps[ep_idx], send_buf, send_lens, send_offsets, recv_buf, recv_lens, recv_offsets, req); - } - - atl_status_t atl_ep_barrier(size_t ep_idx, atl_req_t* req) { - return transport->atl_ep_barrier(eps[ep_idx], req); - } - - atl_status_t atl_ep_bcast(size_t ep_idx, void* buf, size_t len, int root, atl_req_t* req) { - return transport->atl_ep_bcast(eps[ep_idx], buf, len, root, req); - } - - atl_status_t atl_ep_reduce(size_t ep_idx, - const void* send_buf, - void* recv_buf, - size_t len, - int root, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) { - return transport->atl_ep_reduce(eps[ep_idx], send_buf, recv_buf, len, root, dtype, op, req); - } - - atl_status_t atl_ep_reduce_scatter(size_t ep_idx, - const void* send_buf, - void* recv_buf, - size_t 
recv_len, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) { - return transport->atl_ep_reduce_scatter( - eps[ep_idx], send_buf, recv_buf, recv_len, dtype, op, req); - } - - atl_status_t atl_ep_read(size_t ep_idx, - void* buf, - size_t len, - atl_mr_t* mr, - uint64_t addr, - uintptr_t remote_key, - int dst_proc_idx, - atl_req_t* req) { - return transport->atl_ep_read( - eps[ep_idx], buf, len, mr, addr, remote_key, dst_proc_idx, req); - } - - atl_status_t atl_ep_write(size_t ep_idx, - const void* buf, - size_t len, - atl_mr_t* mr, - uint64_t addr, - uintptr_t remote_key, - int dst_proc_idx, - atl_req_t* req) { - return transport->atl_ep_write( - eps[ep_idx], buf, len, mr, addr, remote_key, dst_proc_idx, req); - } - - atl_status_t atl_ep_wait(size_t ep_idx, atl_req_t* req) { - return transport->atl_ep_wait(eps[ep_idx], req); - } - - atl_status_t atl_ep_wait_all(size_t ep_idx, atl_req_t* req, size_t count) { - return transport->atl_ep_wait_all(eps[ep_idx], req, count); - } - - atl_status_t atl_ep_cancel(size_t ep_idx, atl_req_t* req) { - return transport->atl_ep_cancel(eps[ep_idx], req); - } - - atl_status_t atl_ep_poll(size_t ep_idx) { - return transport->atl_ep_poll(eps[ep_idx]); - } - - atl_status_t atl_ep_check(size_t ep_idx, int* is_completed, atl_req_t* req) { - return transport->atl_ep_check(eps[ep_idx], is_completed, req); - } - - size_t get_threads_per_process() { - return threads_per_process; - } - - size_t get_ranks_per_process() { - return ranks_per_process; - } - - int get_rank() { - return rank; - } - - int get_size() { - return size; - } - - int get_r2r_color() { - return transport->atl_get_proc_coord()->local_idx; - } - - int get_host_color() { - return transport->atl_get_proc_coord()->hostname_hash; - } - - /* - * TODO: Temporary change. 
- * Need to define correct to unique id - */ - size_t get_id() { - return 0; - } - - /* static ATL attr for all transport instances - actual values generated by executor */ - static atl_attr_t attr; - - std::unique_ptr tag; - -private: - int rank; - int size; - - size_t threads_per_process; - size_t ranks_per_process; - - std::shared_ptr transport; - std::unique_ptr pmi; - atl_ep_t** eps = nullptr; - - void init_transport(); -}; diff --git a/src/atl/mpi/atl_mpi.cpp b/src/atl/mpi/atl_mpi.cpp index 2e742d80e..0670b44f6 100644 --- a/src/atl/mpi/atl_mpi.cpp +++ b/src/atl/mpi/atl_mpi.cpp @@ -15,187 +15,838 @@ */ #ifdef CCL_ENABLE_MPI +#include "atl_def.h" #include "atl_mpi.hpp" -#include "atl_mpi_impl.cpp" -atl_status_t atl_mpi::atl_set_env(const atl_attr_t& attr) { - return atl_mpi_set_env(attr); -} +#define MPI_BFLOAT16 \ + ({ \ + CCL_THROW_IF_NOT(global_data.bf16.dtype != MPI_DATATYPE_NULL, \ + "unsupported datatype: ATL_DTYPE_BF16"); \ + global_data.bf16.dtype; \ + }) + +#define MPI_FLOAT16 \ + ({ \ + CCL_THROW_IF_NOT(global_data.fp16.dtype != MPI_DATATYPE_NULL, \ + "unsupported datatype: ATL_DTYPE_FP16"); \ + global_data.fp16.dtype; \ + }) + +#define RET2ATL(ret) (ret != MPI_SUCCESS) ? 
ATL_STATUS_FAILURE : ATL_STATUS_SUCCESS

-atl_status_t atl_mpi::atl_init(int* argc,
-                               char*** argv,
-                               atl_attr_t* attr,
-                               const char* main_addr,
-                               std::unique_ptr& pmi) {
+atl_mpi_global_data atl_mpi::global_data{};
+
+/*
+ * Initialise this ATL/MPI context.
+ *
+ * The first context (global_data.ctx_count == 0) also performs the process-wide
+ * setup: set_env(), MPI_Init_thread() requesting MPI_THREAD_MULTIPLE unless MPI
+ * was already initialised by the application, and update_global_data().
+ * Every call then increments ctx_count, refreshes global_coord from
+ * MPI_COMM_WORLD, reads the progress mode from the ATL_PROGRESS_MODE_ENV
+ * environment variable (ATL_PROGRESS_CHECK when unset) and reports the
+ * effective values in attr->out.
+ *
+ * main_addr and pmi are not used by this function.
+ *
+ * Returns ATL_STATUS_SUCCESS, or ATL_STATUS_FAILURE when the process-wide
+ * setup fails.
+ */
+atl_status_t atl_mpi::init(int* argc,
+                           char*** argv,
+                           atl_attr_t* attr,
+                           const char* main_addr,
+                           std::shared_ptr pmi) {
     inited = true;
-    return atl_mpi_init(argc, argv, attr, &ctx, main_addr, pmi.get());
+    CCL_THROW_IF_NOT((sizeof(atl_mpi_req_t) <= sizeof(atl_req_t) - offsetof(atl_req_t, internal)),
+                     "unexpected offset: atl_mpi_request size ",
+                     sizeof(atl_mpi_req_t),
+                     ", atl_request size ",
+                     sizeof(atl_req_t),
+                     ", expected offset ",
+                     offsetof(atl_req_t, internal));
+
+    int ret = MPI_SUCCESS;
+    int is_tag_ub_set = 0;
+    void* tag_ub_ptr = NULL;
+    int required_thread_level = MPI_THREAD_MULTIPLE, provided_thread_level;
+
+    if (global_data.ctx_count == 0) {
+        if (global_data.set_env(*attr)) {
+            goto err_init;
+        }
+
+        MPI_Initialized(&global_data.is_external_init);
+
+        if (!global_data.is_external_init) {
+            ret = MPI_Init_thread(argc, argv, required_thread_level, &provided_thread_level);
+            /* provided_thread_level is only written by a successful call,
+             * so the return code has to be checked before it is compared */
+            if (ret != MPI_SUCCESS)
+                goto err_init;
+
+            if (provided_thread_level < required_thread_level) {
+                LOG_ERROR("unexpected MPI thread level: required ",
+                          required_thread_level,
+                          ", provided ",
+                          provided_thread_level);
+                goto err_init;
+            }
+        }
+        else {
+            LOG_DEBUG("MPI was initialized externaly");
+            MPI_Query_thread(&provided_thread_level);
+            if (provided_thread_level < required_thread_level) {
+                LOG_WARN("MPI was initialized externaly but with unexpected thread level: "
+                         "required ",
+                         required_thread_level,
+                         ", provided ",
+                         provided_thread_level);
+            }
+        }
+
+        if (ret)
+            goto err_init;
+
+        if (global_data.update_global_data(attr) == ATL_STATUS_FAILURE) {
+            goto err_init;
+        }
+    }
+    global_data.ctx_count++;
+
+    coord_update(MPI_COMM_WORLD, global_coord);
+
+    ep_count = attr->in.ep_count;
+
+    char* progress_mode_env;
+    progress_mode_env = getenv(ATL_PROGRESS_MODE_ENV);
+    if (progress_mode_env) {
+        progress_mode = (atl_progress_mode_t)atoi(progress_mode_env);
+    }
+    else {
+        progress_mode = ATL_PROGRESS_CHECK;
+    }
+    sync_coll = attr->in.enable_sync_coll;
+
+    if (global_coord.global_idx == 0) {
+        global_data.print_log_info();
+        LOG_INFO("atl-mpi-ctx: ", (global_data.ctx_count - 1));
+        LOG_INFO(" progress_mode: ", progress_mode);
+        LOG_INFO(" sync_coll: ", sync_coll);
+    }
+
+    MPI_Comm_get_attr(MPI_COMM_WORLD, MPI_TAG_UB, &tag_ub_ptr, &is_tag_ub_set);
+
+    /* report actual attributes back to upper level */
+    attr->out.enable_shm = 0;
+    attr->out.enable_rma = 0;
+    attr->out.enable_hmem = attr->in.enable_hmem & global_data.mpi_lib_attr.hmem;
+    attr->out.mnic_type = global_data.mnic_type;
+    attr->out.mnic_count = global_data.mnic_count;
+    attr->out.tag_bits = 32;
+    attr->out.max_tag = (is_tag_ub_set) ? *((int*)tag_ub_ptr) : 0;
+    attr->out.max_order_waw_size = 0;
+
+    return ATL_STATUS_SUCCESS;
+
+err_init:
+    return ATL_STATUS_FAILURE;
+}
+
+void atl_mpi::coord_update(MPI_Comm base_comm, atl_proc_coord_t& coord) {
+    MPI_Comm_rank(base_comm, (int*)&(coord.global_idx));
+    MPI_Comm_size(base_comm, (int*)&(coord.global_count));
+
+    MPI_Comm local_comm;
+    MPI_Comm_split_type(
+        base_comm, MPI_COMM_TYPE_SHARED, coord.global_count, MPI_INFO_NULL, &local_comm);
+    MPI_Comm_rank(local_comm, (int*)&(coord.local_idx));
+    MPI_Comm_size(local_comm, (int*)&(coord.local_count));
+    MPI_Comm_free(&local_comm);
+
+    char my_hostname[ATL_MAX_HOSTNAME_LEN] = { 0 };
+    gethostname(my_hostname, ATL_MAX_HOSTNAME_LEN - 1);
+    coord.hostname_hash = std::hash{}(my_hostname);
+}
+
+/*
+ * Release the communicators owned by every endpoint in eps. In
+ * ATL_PROGRESS_POLL mode the dummy receive is cancelled and completed first.
+ */
+void atl_mpi::comms_free(std::vector& eps) {
+    for (size_t i = 0; i < eps.size(); i++) {
+        atl_mpi_ep_t& mpi_ep = eps[i];
+
+        if (progress_mode == ATL_PROGRESS_POLL) {
+            /* MPI_Cancel only marks the request for cancellation: it still has to be
+             * completed, otherwise the request leaks and keeps dummy_comm alive
+             * after MPI_Comm_free */
+            MPI_Cancel(&(mpi_ep.dummy_req.native_req));
+            MPI_Wait(&(mpi_ep.dummy_req.native_req), MPI_STATUS_IGNORE);
+            MPI_Comm_free(&mpi_ep.dummy_comm);
+        }
+        MPI_Comm_free(&mpi_ep.mpi_comm);
+    }
 }

-atl_status_t atl_mpi::atl_finalize() {
+atl_status_t atl_mpi::finalize() {
     is_finalized = true;
-    return atl_mpi_finalize(ctx);
+ + int ret = MPI_SUCCESS; + + global_data.ctx_count--; + if (global_coord.global_idx == 0) { + LOG_INFO("finalize atl-mpi ctx, remaining ctx_count ", global_data.ctx_count); + } + + int is_mpi_finalized = 0; + MPI_Finalized(&is_mpi_finalized); + + if (!is_mpi_finalized) { + if (global_data.ctx_count == 0) { + global_data.bf16_finalize(); + global_data.fp16_finalize(); + if (!global_data.is_external_init) { + ret = MPI_Finalize(); + } + else { + LOG_DEBUG("MPI_Init has been called externally, skip MPI_Finalize"); + } + + if (global_coord.global_idx == 0) { + LOG_INFO("finalized last atl-mpi ctx"); + } + } + } + else { + if ((global_data.ctx_count == 0) && (global_coord.global_idx == 0)) { + LOG_WARN("MPI_Finalize has been called before CCL finalization"); + } + } + + return RET2ATL(ret); } -atl_status_t atl_mpi::atl_update(std::unique_ptr& pmi) { +atl_status_t atl_mpi::update(std::shared_ptr pmi) { (void)pmi; return ATL_STATUS_UNSUPPORTED; } -atl_ep_t** atl_mpi::atl_get_eps() { - return ctx->eps; +atl_status_t atl_mpi::mr_reg(const void* buf, size_t len, atl_mr_t** mr) { + return ATL_STATUS_UNSUPPORTED; } -atl_proc_coord_t* atl_mpi::atl_get_proc_coord() { - return &(ctx->coord); +atl_status_t atl_mpi::mr_dereg(atl_mr_t* mr) { + return ATL_STATUS_UNSUPPORTED; } -atl_status_t atl_mpi::atl_mr_reg(const void* buf, size_t len, atl_mr_t** mr) { - return atl_mpi_mr_reg(ctx, buf, len, mr); -} +atl_status_t atl_mpi::send(atl_mpi_ep_t& ep, + const void* buf, + size_t len, + int dst_proc_idx, + uint64_t tag, + atl_req_t* req) { + atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); -atl_status_t atl_mpi::atl_mr_dereg(atl_mr_t* mr) { - return atl_mpi_mr_dereg(ctx, mr); -} + init_req(req); + + int ret = + MPI_Isend(buf, len, MPI_CHAR, dst_proc_idx, (int)tag, ep.mpi_comm, &mpi_req->native_req); -atl_status_t atl_mpi::atl_ep_send(atl_ep_t* ep, - const void* buf, - size_t len, - int dst_proc_idx, - uint64_t tag, - atl_req_t* req) { - return atl_mpi_ep_send(ep, buf, len, 
dst_proc_idx, tag, req); + check_ep(ep); + + return RET2ATL(ret); } -atl_status_t atl_mpi::atl_ep_recv(atl_ep_t* ep, - void* buf, - size_t len, - int src_proc_idx, - uint64_t tag, - atl_req_t* req) { - return atl_mpi_ep_recv(ep, buf, len, src_proc_idx, tag, req); +atl_status_t atl_mpi::recv(atl_mpi_ep_t& ep, + void* buf, + size_t len, + int src_proc_idx, + uint64_t tag, + atl_req_t* req) { + atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); + + init_req(req); + + int ret = + MPI_Irecv(buf, len, MPI_CHAR, src_proc_idx, (int)tag, ep.mpi_comm, &mpi_req->native_req); + + check_ep(ep); + + return RET2ATL(ret); } -atl_status_t atl_mpi::atl_ep_probe(atl_ep_t* ep, - int src_proc_idx, - uint64_t tag, - int* found, - size_t* recv_len) { - return atl_mpi_ep_probe(ep, src_proc_idx, tag, found, recv_len); +atl_status_t atl_mpi::probe(atl_mpi_ep_t& ep, + int src_proc_idx, + uint64_t tag, + int* found, + size_t* recv_len) { + int flag = 0, len = 0, ret; + MPI_Status status; + + ret = MPI_Iprobe(src_proc_idx, tag, ep.mpi_comm, &flag, &status); + if (flag) { + MPI_Get_count(&status, MPI_BYTE, &len); + } + + if (found) + *found = flag; + if (recv_len) + *recv_len = len; + + check_ep(ep); + + return RET2ATL(ret); } -atl_status_t atl_mpi::atl_ep_allgatherv(atl_ep_t* ep, - const void* send_buf, - size_t send_len, - void* recv_buf, - const int* recv_lens, - const int* offsets, - atl_req_t* req) { - return atl_mpi_ep_allgatherv(ep, send_buf, send_len, recv_buf, recv_lens, offsets, req); +atl_status_t atl_mpi::allgatherv(atl_mpi_ep_t& ep, + const void* send_buf, + size_t send_len, + void* recv_buf, + const int* recv_lens, + const int* offsets, + atl_req_t* req) { + int ret = MPI_SUCCESS; + + atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); + + init_req(req); + + if (sync_coll) { + ret = MPI_Allgatherv((send_buf && (send_buf == recv_buf)) ? 
MPI_IN_PLACE : send_buf, + send_len, + MPI_CHAR, + recv_buf, + recv_lens, + offsets, + MPI_CHAR, + ep.mpi_comm); + } + else { + ret = MPI_Iallgatherv((send_buf && (send_buf == recv_buf)) ? MPI_IN_PLACE : send_buf, + send_len, + MPI_CHAR, + recv_buf, + recv_lens, + offsets, + MPI_CHAR, + ep.mpi_comm, + &mpi_req->native_req); + } + + check_ep(ep); + + return RET2ATL(ret); } -atl_status_t atl_mpi::atl_ep_allreduce(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t len, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) { - return atl_mpi_ep_allreduce(ep, send_buf, recv_buf, len, dtype, op, req); +atl_status_t atl_mpi::allreduce(atl_mpi_ep_t& ep, + const void* send_buf, + void* recv_buf, + size_t len, + atl_datatype_t dtype, + atl_reduction_t op, + atl_req_t* req) { + int ret = MPI_SUCCESS; + + atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); + + MPI_Datatype mpi_dtype = atl2mpi_dtype(dtype); + MPI_Op mpi_op = atl2mpi_op(op, mpi_dtype); + + init_req(req); + + if (sync_coll) { + ret = MPI_Allreduce((send_buf && (send_buf == recv_buf)) ? MPI_IN_PLACE : send_buf, + recv_buf, + len, + mpi_dtype, + mpi_op, + ep.mpi_comm); + } + else { + //printf("atl_mpi: send_buf %p, recv_buf %p\n", send_buf, recv_buf); + ret = MPI_Iallreduce((send_buf && (send_buf == recv_buf)) ? MPI_IN_PLACE : send_buf, + recv_buf, + len, + mpi_dtype, + mpi_op, + ep.mpi_comm, + &mpi_req->native_req); + } + + check_ep(ep); + + return RET2ATL(ret); } -atl_status_t atl_mpi::atl_ep_alltoall(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - int len, - atl_req_t* req) { - return atl_mpi_ep_alltoall(ep, send_buf, recv_buf, len, req); +atl_status_t atl_mpi::alltoall(atl_mpi_ep_t& ep, + const void* send_buf, + void* recv_buf, + int len, + atl_req_t* req) { + int ret = MPI_SUCCESS; + + atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); + + init_req(req); + + if (sync_coll) { + ret = MPI_Alltoall((send_buf && (send_buf == recv_buf)) ? 
MPI_IN_PLACE : send_buf, + len, + MPI_CHAR, + recv_buf, + len, + MPI_CHAR, + ep.mpi_comm); + } + else { + ret = MPI_Ialltoall((send_buf && (send_buf == recv_buf)) ? MPI_IN_PLACE : send_buf, + len, + MPI_CHAR, + recv_buf, + len, + MPI_CHAR, + ep.mpi_comm, + &mpi_req->native_req); + } + + check_ep(ep); + + return RET2ATL(ret); } -atl_status_t atl_mpi::atl_ep_alltoallv(atl_ep_t* ep, - const void* send_buf, - const int* send_lens, - const int* send_offsets, - void* recv_buf, - const int* recv_lens, - const int* recv_offsets, - atl_req_t* req) { - return atl_mpi_ep_alltoallv( - ep, send_buf, send_lens, send_offsets, recv_buf, recv_lens, recv_offsets, req); +atl_status_t atl_mpi::alltoallv(atl_mpi_ep_t& ep, + const void* send_buf, + const int* send_lens, + const int* send_offsets, + void* recv_buf, + const int* recv_lens, + const int* recv_offsets, + atl_req_t* req) { + int ret = MPI_SUCCESS; + + atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); + + init_req(req); + + if (sync_coll) { + ret = MPI_Alltoallv((send_buf && (send_buf == recv_buf)) ? MPI_IN_PLACE : send_buf, + send_lens, + send_offsets, + MPI_CHAR, + recv_buf, + recv_lens, + recv_offsets, + MPI_CHAR, + ep.mpi_comm); + } + else { + ret = MPI_Ialltoallv((send_buf && (send_buf == recv_buf)) ? 
MPI_IN_PLACE : send_buf, + send_lens, + send_offsets, + MPI_CHAR, + recv_buf, + recv_lens, + recv_offsets, + MPI_CHAR, + ep.mpi_comm, + &mpi_req->native_req); + } + + check_ep(ep); + + return RET2ATL(ret); } -atl_status_t atl_mpi::atl_ep_barrier(atl_ep_t* ep, atl_req_t* req) { - return atl_mpi_ep_barrier(ep, req); +atl_status_t atl_mpi::barrier(atl_mpi_ep_t& ep, atl_req_t* req) { + int ret = MPI_SUCCESS; + + atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); + + init_req(req); + + if (sync_coll) { + ret = MPI_Barrier(ep.mpi_comm); + } + else { + ret = MPI_Ibarrier(ep.mpi_comm, &mpi_req->native_req); + } + + check_ep(ep); + + return RET2ATL(ret); } -atl_status_t atl_mpi::atl_ep_bcast(atl_ep_t* ep, void* buf, size_t len, int root, atl_req_t* req) { - return atl_mpi_ep_bcast(ep, buf, len, root, req); +atl_status_t atl_mpi::bcast(atl_mpi_ep_t& ep, void* buf, size_t len, int root, atl_req_t* req) { + int ret = MPI_SUCCESS; + + atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); + + init_req(req); + + if (sync_coll) { + ret = MPI_Bcast(buf, len, MPI_CHAR, root, ep.mpi_comm); + } + else { + ret = MPI_Ibcast(buf, len, MPI_CHAR, root, ep.mpi_comm, &mpi_req->native_req); + } + + check_ep(ep); + + return RET2ATL(ret); } -atl_status_t atl_mpi::atl_ep_reduce(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t len, - int root, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) { - return atl_mpi_ep_reduce(ep, send_buf, recv_buf, len, root, dtype, op, req); +atl_status_t atl_mpi::reduce(atl_mpi_ep_t& ep, + const void* send_buf, + void* recv_buf, + size_t len, + int root, + atl_datatype_t dtype, + atl_reduction_t op, + atl_req_t* req) { + int ret = MPI_SUCCESS; + + atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); + + int my_proc_idx = ep.coord->global_idx; + MPI_Datatype mpi_dtype = atl2mpi_dtype(dtype); + MPI_Op mpi_op = atl2mpi_op(op, mpi_dtype); + + init_req(req); + + if (sync_coll) { + ret = MPI_Reduce( + (send_buf && (send_buf 
== recv_buf) && (root == my_proc_idx)) ? MPI_IN_PLACE : send_buf, + recv_buf, + len, + mpi_dtype, + mpi_op, + root, + ep.mpi_comm); + } + else { + ret = MPI_Ireduce( + (send_buf && (send_buf == recv_buf) && (root == my_proc_idx)) ? MPI_IN_PLACE : send_buf, + recv_buf, + len, + mpi_dtype, + mpi_op, + root, + ep.mpi_comm, + &mpi_req->native_req); + } + + check_ep(ep); + + return RET2ATL(ret); } -atl_status_t atl_mpi::atl_ep_reduce_scatter(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t recv_len, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) { - return atl_mpi_ep_reduce_scatter(ep, send_buf, recv_buf, recv_len, dtype, op, req); +atl_status_t atl_mpi::reduce_scatter(atl_mpi_ep_t& ep, + const void* send_buf, + void* recv_buf, + size_t recv_len, + atl_datatype_t dtype, + atl_reduction_t op, + atl_req_t* req) { + int ret = MPI_SUCCESS; + + atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); + + MPI_Datatype mpi_dtype = atl2mpi_dtype(dtype); + MPI_Op mpi_op = atl2mpi_op(op, mpi_dtype); + + init_req(req); + + if (sync_coll) { + ret = + MPI_Reduce_scatter_block((send_buf && (send_buf == recv_buf)) ? MPI_IN_PLACE : send_buf, + recv_buf, + recv_len, + mpi_dtype, + mpi_op, + ep.mpi_comm); + } + else { + ret = MPI_Ireduce_scatter_block( + (send_buf && (send_buf == recv_buf)) ? 
MPI_IN_PLACE : send_buf, + recv_buf, + recv_len, + mpi_dtype, + mpi_op, + ep.mpi_comm, + &mpi_req->native_req); + } + + check_ep(ep); + + return RET2ATL(ret); } -atl_status_t atl_mpi::atl_ep_read(atl_ep_t* ep, - void* buf, - size_t len, - atl_mr_t* mr, - uint64_t addr, - uintptr_t remote_key, - int dst_proc_idx, - atl_req_t* req) { - return atl_mpi_ep_read(ep, buf, len, mr, addr, remote_key, dst_proc_idx, req); +atl_status_t atl_mpi::read(atl_mpi_ep_t& ep, + void* buf, + size_t len, + atl_mr_t* mr, + uint64_t addr, + uintptr_t remote_key, + int dst_proc_idx, + atl_req_t* req) { + return ATL_STATUS_UNSUPPORTED; } -atl_status_t atl_mpi::atl_ep_write(atl_ep_t* ep, - const void* buf, - size_t len, - atl_mr_t* mr, - uint64_t addr, - uintptr_t remote_key, - int dst_proc_idx, - atl_req_t* req) { - return atl_mpi_ep_write(ep, buf, len, mr, addr, remote_key, dst_proc_idx, req); +atl_status_t atl_mpi::write(atl_mpi_ep_t& ep, + const void* buf, + size_t len, + atl_mr_t* mr, + uint64_t addr, + uintptr_t remote_key, + int dst_proc_idx, + atl_req_t* req) { + return ATL_STATUS_UNSUPPORTED; } -atl_status_t atl_mpi::atl_ep_wait(atl_ep_t* ep, atl_req_t* req) { - return atl_mpi_ep_wait(ep, req); +atl_status_t atl_mpi::wait(atl_mpi_ep_t& ep, atl_req_t* req) { + int ret; + MPI_Status status; + atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); + ret = MPI_Wait(&mpi_req->native_req, &status); + mpi_req->comp_state = ATL_MPI_COMP_COMPLETED; + return RET2ATL(ret); } -atl_status_t atl_mpi::atl_ep_wait_all(atl_ep_t* ep, atl_req_t* req, size_t count) { - return atl_mpi_ep_wait_all(ep, req, count); +atl_status_t atl_mpi::wait_all(atl_mpi_ep_t& ep, atl_req_t* req, size_t count) { + return ATL_STATUS_UNSUPPORTED; } -atl_status_t atl_mpi::atl_ep_cancel(atl_ep_t* ep, atl_req_t* req) { +atl_status_t atl_mpi::cancel(atl_mpi_ep_t& ep, atl_req_t* req) { return ATL_STATUS_UNSUPPORTED; } -atl_status_t atl_mpi::atl_ep_poll(atl_ep_t* ep) { - return atl_mpi_ep_poll(ep); +atl_status_t 
atl_mpi::poll(atl_mpi_ep_t& ep) { + if (progress_mode == ATL_PROGRESS_POLL) { + return ep_progress(ep, &(ep.dummy_req)); + } + + return ATL_STATUS_SUCCESS; } -atl_status_t atl_mpi::atl_ep_check(atl_ep_t* ep, int* is_completed, atl_req_t* req) { - return atl_mpi_ep_check(ep, is_completed, req); +atl_status_t atl_mpi::check(atl_mpi_ep_t& ep, atl_req_t* req) { + atl_status_t status; + + atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); + + CCL_THROW_IF_NOT(!req->is_completed, "request is already completed"); + CCL_THROW_IF_NOT(mpi_req->comp_state == ATL_MPI_COMP_POSTED, "request is already completed"); + + if (mpi_req->native_req == MPI_REQUEST_NULL) { + mpi_req->comp_state = ATL_MPI_COMP_COMPLETED; + } + + req->is_completed = (mpi_req->comp_state == ATL_MPI_COMP_COMPLETED); + if (req->is_completed) { + return ATL_STATUS_SUCCESS; + } + + status = ep_progress(ep, mpi_req); + req->is_completed = (mpi_req->comp_state == ATL_MPI_COMP_COMPLETED); + + return status; } + atl_mpi::~atl_mpi() { if (!is_finalized) - atl_finalize(); + finalize(); +} + +MPI_Datatype atl_mpi::atl2mpi_dtype(atl_datatype_t dtype) { + switch (dtype) { + case ATL_DTYPE_INT8: return MPI_CHAR; + case ATL_DTYPE_UINT8: return MPI_UNSIGNED_CHAR; + case ATL_DTYPE_INT16: return MPI_INT16_T; + case ATL_DTYPE_UINT16: return MPI_UINT16_T; + case ATL_DTYPE_INT32: return MPI_INT; + case ATL_DTYPE_UINT32: return MPI_UINT32_T; + case ATL_DTYPE_INT64: return MPI_LONG_LONG; + case ATL_DTYPE_UINT64: return MPI_UNSIGNED_LONG_LONG; + case ATL_DTYPE_FLOAT16: return MPI_FLOAT16; + case ATL_DTYPE_FLOAT32: return MPI_FLOAT; + case ATL_DTYPE_FLOAT64: return MPI_DOUBLE; + case ATL_DTYPE_BFLOAT16: return MPI_BFLOAT16; + default: printf("unknown datatype: %d\n", dtype); exit(1); + } +} + +inline atl_status_t atl_mpi::ep_progress(atl_mpi_ep_t& ep, atl_mpi_req_t* req) { + int flag = 0; + int ret = MPI_Test(&req->native_req, &flag, MPI_STATUS_IGNORE); + + if (flag) { + req->comp_state = ATL_MPI_COMP_COMPLETED; + } + + 
return RET2ATL(ret); +} + +void atl_mpi::init_req(atl_req_t* req) { + atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); + mpi_req->native_req = MPI_REQUEST_NULL; + mpi_req->comp_state = ATL_MPI_COMP_POSTED; + req->is_completed = 0; +} + +MPI_Op atl_mpi::atl2mpi_op(atl_reduction_t rtype, MPI_Datatype dtype) { +#ifdef ATL_MPI_BF16 + if (dtype == global_data.bf16.dtype) + return global_data.atl2mpi_op_bf16(rtype); +#endif // ATL_MPI_BF16 + +#ifdef ATL_MPI_FP16 + if (dtype == global_data.fp16.dtype) + return global_data.atl2mpi_op_fp16(rtype); +#endif // ATL_MPI_FP16 + + (void)dtype; + switch (rtype) { + case ATL_REDUCTION_SUM: return MPI_SUM; + case ATL_REDUCTION_PROD: return MPI_PROD; + case ATL_REDUCTION_MIN: return MPI_MIN; + case ATL_REDUCTION_MAX: return MPI_MAX; + default: printf("unknown reduction type: %d\n", rtype); exit(1); + } +} + +size_t atl_mpi::get_ep_idx(size_t ep_idx) { + size_t mpi_ep_idx = ep_idx; + if (global_data.extra_ep) + mpi_ep_idx += global_data.extra_ep; + return mpi_ep_idx; +} + +atl_status_t atl_mpi::ep_init(std::vector& eps) { + atl_mpi_ep_t base_ep; + base_ep.mpi_comm = MPI_COMM_WORLD; + base_ep.dummy_comm = MPI_COMM_WORLD; + base_ep.idx = 0; + base_ep.coord = nullptr; + std::vector base_eps(ep_count, base_ep); + return comm_split(base_eps, eps, 0); +} + +#ifdef ENABLE_DEBUG +void atl_mpi::check_ep(atl_mpi_ep_t& ep) { + check_comm_ep_idx(ep.mpi_comm, get_ep_idx(ep.idx)); +} +#endif // ENABLE_DEBUG + +void atl_mpi::check_comm_nic_idx(MPI_Comm comm, size_t expected_idx) { + char expected_idx_str[MPI_MAX_INFO_VAL] = { 0 }; + snprintf(expected_idx_str, MPI_MAX_INFO_VAL, "%zu", expected_idx); + check_comm_info(comm, global_data.NIC_IDX_KEY, expected_idx_str); +} + +void atl_mpi::check_comm_ep_idx(MPI_Comm comm, size_t expected_idx) { + if (global_data.mpi_lib_attr.type == global_data.ATL_MPI_LIB_NONE) + return; + + char expected_idx_str[MPI_MAX_INFO_VAL] = { 0 }; + snprintf(expected_idx_str, MPI_MAX_INFO_VAL, "%zu", expected_idx); + 
check_comm_info(comm, global_data.EP_IDX_KEY, expected_idx_str);
+}
+
+void atl_mpi::check_comm_info(MPI_Comm comm, const char* key, const char* expected_value) {
+    atl_mpi_comm_info_t info = atl_mpi::get_comm_info(comm, key);
+
+    CCL_THROW_IF_NOT(info.found, "MPI comm key ", key, " was not set");
+    CCL_THROW_IF_NOT(!strcmp(info.value, expected_value),
+                     "MPI comm key ",
+                     key,
+                     ": expected: ",
+                     expected_value,
+                     ", read: ",
+                     info.value);
+}
+
+void atl_mpi::set_env(const atl_attr_t& attr) {
+    global_data.set_env(attr);
+}
+
+/*
+ * Create ep_count endpoints and append them to eps.
+ *
+ * For each index the communicator of base_eps[idx] is split with the given
+ * color (key 0) and the result is tagged through MPI_Info with its EP index
+ * and, when multi-NIC is enabled, its NIC index. In ATL_PROGRESS_POLL mode the
+ * dummy communicator is split the same way and a zero-length receive is posted
+ * on it. base_eps must hold at least ep_count entries.
+ *
+ * If an MPI_Comm_split call fails, the endpoints already stored in eps are
+ * freed, ctx_count is decremented and, for the last context, MPI is finalised
+ * unless it was initialised externally.
+ *
+ * Returns ATL_STATUS_SUCCESS, or ATL_STATUS_FAILURE if a split failed.
+ */
+atl_status_t atl_mpi::comm_split(const std::vector& base_eps,
+                                 std::vector& eps,
+                                 size_t color) {
+    /* start from success so that ep_count == 0 does not test or return an
+     * indeterminate value */
+    int ret = MPI_SUCCESS;
+    /* value-initialised: fields not assigned here (coord, and dummy_* outside
+     * poll mode) must not be indeterminate when the struct is copied into eps */
+    atl_mpi_ep_t ep{};
+    for (size_t idx = 0; idx < ep_count; idx++) {
+        /* size_t matches get_ep_idx(), the "%zu" conversion and check_comm_ep_idx() */
+        size_t mpi_ep_idx = get_ep_idx(idx);
+        char mpi_ep_idx_str[MPI_MAX_INFO_VAL] = { 0 };
+
+        size_t nic_idx = 0;
+        char nic_idx_str[MPI_MAX_INFO_VAL] = { 0 };
+
+        ret = MPI_Comm_split(base_eps[idx].mpi_comm, color, 0, &ep.mpi_comm);
+        if (ret) {
+            LOG_ERROR("MPI_Comm_split error, ep_idx ", idx);
+            break;
+        }
+
+        MPI_Info info;
+        MPI_Info_create(&info);
+
+        /* set EP index */
+        snprintf(mpi_ep_idx_str, MPI_MAX_INFO_VAL, "%zu", mpi_ep_idx);
+        MPI_Info_set(info, global_data.EP_IDX_KEY, mpi_ep_idx_str);
+
+        if (global_data.mnic_type != ATL_MNIC_NONE) {
+            /* set NIC index */
+            nic_idx = idx;
+            if (global_data.mnic_offset == ATL_MNIC_OFFSET_LOCAL_PROC_IDX) {
+                nic_idx += global_coord.local_idx;
+            }
+            nic_idx %= global_data.mnic_count;
+            snprintf(nic_idx_str, MPI_MAX_INFO_VAL, "%zu", nic_idx);
+            MPI_Info_set(info, global_data.NIC_IDX_KEY, nic_idx_str);
+
+            LOG_INFO("select nic: ep_idx ",
+                     idx,
+                     ", local_proc_idx ",
+                     global_coord.local_idx,
+                     ", nic_idx ",
+                     nic_idx);
+        }
+
+        MPI_Comm_set_info(ep.mpi_comm, info);
+
+        if (progress_mode == ATL_PROGRESS_POLL) {
+            ret = MPI_Comm_split(base_eps[idx].dummy_comm, color, 0, &ep.dummy_comm);
+            if (ret) {
+                LOG_ERROR("MPI_Comm_split error, ep_idx ", idx);
+                /* this endpoint is not in eps yet, so comms_free() below will not
+                 * see it: release what this iteration has already created */
+                MPI_Info_free(&info);
+                MPI_Comm_free(&ep.mpi_comm);
+                break;
+            }
+            MPI_Comm_set_info(ep.dummy_comm, info);
+            MPI_Irecv(NULL, 0, MPI_CHAR, 0, 0, ep.dummy_comm, &(ep.dummy_req.native_req));
+
+            check_comm_ep_idx(ep.dummy_comm, mpi_ep_idx);
+            if (global_data.mnic_type != ATL_MNIC_NONE) {
+                check_comm_nic_idx(ep.dummy_comm, nic_idx);
+            }
+        }
+
+        MPI_Info_free(&info);
+
+        check_comm_ep_idx(ep.mpi_comm, mpi_ep_idx);
+        if (global_data.mnic_type != ATL_MNIC_NONE) {
+            check_comm_nic_idx(ep.mpi_comm, nic_idx);
+        }
+
+        LOG_DEBUG("atl-mpi-ep: ", idx, ", ep_idx ", mpi_ep_idx, ", nic_idx ", nic_idx);
+
+        ep.idx = idx;
+        eps.push_back(ep);
+    }
+
+    if (ret) {
+        comms_free(eps);
+        global_data.ctx_count--;
+        if (global_data.ctx_count == 0) {
+            global_data.bf16_finalize();
+            global_data.fp16_finalize();
+            if (!global_data.is_external_init) {
+                MPI_Finalize();
+            }
+        }
+    }
+
+    return RET2ATL(ret);
+}
+
+atl_mpi_env_info_t atl_mpi::get_env_info(const char* key) {
+    atl_mpi_env_info_t res;
+    snprintf(res.key, MPI_MAX_INFO_KEY, "%s", key);
+    MPI_Info_get(MPI_INFO_ENV, key, MPI_MAX_INFO_VAL, res.value, &res.found);
+    return res;
+}
+
+atl_mpi_comm_info_t atl_mpi::get_comm_info(MPI_Comm comm, const char* key) {
+    MPI_Info info;
+    atl_mpi_comm_info_t res;
+
+    res.comm = comm;
+    snprintf(res.key, MPI_MAX_INFO_KEY, "%s", key);
+
+    MPI_Comm_get_info(res.comm, &info);
+    MPI_Info_get(info, key, MPI_MAX_INFO_VAL, res.value, &res.found);
+    MPI_Info_free(&info);
+
+    return res;
 }

 #endif // CCL_ENABLE_MPI
diff --git a/src/atl/mpi/atl_mpi.hpp b/src/atl/mpi/atl_mpi.hpp
index 03760a0ae..e82997e27 100644
--- a/src/atl/mpi/atl_mpi.hpp
+++ b/src/atl/mpi/atl_mpi.hpp
@@ -13,153 +13,212 @@
 See the License for the specific language governing permissions and
 limitations under the License.
*/ +#pragma once #ifdef CCL_ENABLE_MPI +#include -#include "atl.h" +#include "atl_mpi_global_data.hpp" -class atl_mpi final : public iatl { +typedef enum { ATL_MPI_COMP_POSTED, ATL_MPI_COMP_COMPLETED } atl_mpi_comp_state_t; + +typedef struct { + MPI_Request native_req; + atl_mpi_comp_state_t comp_state; +} atl_mpi_req_t; + +typedef struct { + MPI_Comm mpi_comm; + + /* dummy recv operation to ensure progress in atl_poll */ + atl_mpi_req_t dummy_req; + MPI_Comm dummy_comm; + size_t idx; + atl_proc_coord_t* coord; +} atl_mpi_ep_t; + +typedef struct atl_mpi_env_info { + int found; + char key[MPI_MAX_INFO_KEY]; + char value[MPI_MAX_INFO_VAL]; + + atl_mpi_env_info() { + found = 0; + memset(key, 0, MPI_MAX_INFO_KEY); + memset(value, 0, MPI_MAX_INFO_VAL); + } +} atl_mpi_env_info_t; + +typedef struct atl_mpi_comm_info : atl_mpi_env_info_t { + MPI_Comm comm; + + atl_mpi_comm_info() { + comm = MPI_COMM_WORLD; + } +} atl_mpi_comm_info_t; + +class atl_mpi { public: atl_mpi() = default; - ~atl_mpi() override; - - static atl_status_t atl_set_env(const atl_attr_t& attr); - - atl_status_t atl_init(int* argc, - char*** argv, - atl_attr_t* attr, - const char* main_addr, - std::unique_ptr& pmi) override; - - atl_status_t atl_update(std::unique_ptr& pmi) override; - - atl_ep_t** atl_get_eps() override; - - atl_proc_coord_t* atl_get_proc_coord() override; - - atl_status_t atl_mr_reg(const void* buf, size_t len, atl_mr_t** mr) override; - - atl_status_t atl_mr_dereg(atl_mr_t* mr) override; - - atl_status_t atl_ep_send(atl_ep_t* ep, - const void* buf, - size_t len, - int dst_proc_idx, - uint64_t tag, - atl_req_t* req) override; - - atl_status_t atl_ep_recv(atl_ep_t* ep, - void* buf, - size_t len, - int src_proc_idx, - uint64_t tag, - atl_req_t* req) override; - - atl_status_t atl_ep_probe(atl_ep_t* ep, - int src_proc_idx, - uint64_t tag, - int* found, - size_t* recv_len) override; - - atl_status_t atl_ep_allgatherv(atl_ep_t* ep, - const void* send_buf, - size_t send_len, - void* recv_buf, 
- const int* recv_lens, - const int* offsets, - atl_req_t* req) override; - - atl_status_t atl_ep_allreduce(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t len, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) override; - - atl_status_t atl_ep_alltoall(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - int len, - atl_req_t* req) override; - - atl_status_t atl_ep_alltoallv(atl_ep_t* ep, - const void* send_buf, - const int* send_lens, - const int* send_offsets, - void* recv_buf, - const int* recv_lens, - const int* recv_offsets, - atl_req_t* req) override; - - atl_status_t atl_ep_barrier(atl_ep_t* ep, atl_req_t* req) override; - - atl_status_t atl_ep_bcast(atl_ep_t* ep, - void* buf, - size_t len, - int root, - atl_req_t* req) override; - - atl_status_t atl_ep_reduce(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t len, - int root, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) override; - - atl_status_t atl_ep_reduce_scatter(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t recv_len, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) override; - - atl_status_t atl_ep_read(atl_ep_t* ep, - void* buf, - size_t len, - atl_mr_t* mr, - uint64_t addr, - uintptr_t remote_key, - int dst_proc_idx, - atl_req_t* req) override; - - atl_status_t atl_ep_write(atl_ep_t* ep, - const void* buf, - size_t len, - atl_mr_t* mr, - uint64_t addr, - uintptr_t remote_key, - int dst_proc_idx, - atl_req_t* req) override; - - atl_status_t atl_ep_wait(atl_ep_t* ep, atl_req_t* req) override; - - atl_status_t atl_ep_wait_all(atl_ep_t* ep, atl_req_t* req, size_t count) override; - - atl_status_t atl_ep_cancel(atl_ep_t* ep, atl_req_t* req) override; - - atl_status_t atl_ep_poll(atl_ep_t* ep) override; - - atl_status_t atl_ep_check(atl_ep_t* ep, int* is_completed, atl_req_t* req) override; - - atl_status_t atl_finalize() override; + ~atl_mpi(); + + atl_status_t init(int* argc, + char*** argv, + 
atl_attr_t* attr, + const char* main_addr, + std::shared_ptr pmi); + + atl_status_t update(std::shared_ptr pmi); + + atl_status_t mr_reg(const void* buf, size_t len, atl_mr_t** mr); + + atl_status_t mr_dereg(atl_mr_t* mr); + + atl_status_t send(atl_mpi_ep_t& ep, + const void* buf, + size_t len, + int dst_proc_idx, + uint64_t tag, + atl_req_t* req); + + atl_status_t recv(atl_mpi_ep_t& ep, + void* buf, + size_t len, + int src_proc_idx, + uint64_t tag, + atl_req_t* req); + + atl_status_t probe(atl_mpi_ep_t& ep, + int src_proc_idx, + uint64_t tag, + int* found, + size_t* recv_len); + + atl_status_t allgatherv(atl_mpi_ep_t& ep, + const void* send_buf, + size_t send_len, + void* recv_buf, + const int* recv_lens, + const int* offsets, + atl_req_t* req); + + atl_status_t allreduce(atl_mpi_ep_t& ep, + const void* send_buf, + void* recv_buf, + size_t len, + atl_datatype_t dtype, + atl_reduction_t op, + atl_req_t* req); + + atl_status_t alltoall(atl_mpi_ep_t& ep, + const void* send_buf, + void* recv_buf, + int len, + atl_req_t* req); + + atl_status_t alltoallv(atl_mpi_ep_t& ep, + const void* send_buf, + const int* send_lens, + const int* send_offsets, + void* recv_buf, + const int* recv_lens, + const int* recv_offsets, + atl_req_t* req); + + atl_status_t barrier(atl_mpi_ep_t& ep, atl_req_t* req); + + atl_status_t bcast(atl_mpi_ep_t& ep, void* buf, size_t len, int root, atl_req_t* req); + + atl_status_t reduce(atl_mpi_ep_t& ep, + const void* send_buf, + void* recv_buf, + size_t len, + int root, + atl_datatype_t dtype, + atl_reduction_t op, + atl_req_t* req); + + atl_status_t reduce_scatter(atl_mpi_ep_t& ep, + const void* send_buf, + void* recv_buf, + size_t recv_len, + atl_datatype_t dtype, + atl_reduction_t op, + atl_req_t* req); + + atl_status_t read(atl_mpi_ep_t& ep, + void* buf, + size_t len, + atl_mr_t* mr, + uint64_t addr, + uintptr_t remote_key, + int dst_proc_idx, + atl_req_t* req); + + atl_status_t write(atl_mpi_ep_t& ep, + const void* buf, + size_t len, + atl_mr_t* 
mr, + uint64_t addr, + uintptr_t remote_key, + int dst_proc_idx, + atl_req_t* req); + + atl_status_t wait(atl_mpi_ep_t& ep, atl_req_t* req); + + atl_status_t wait_all(atl_mpi_ep_t& ep, atl_req_t* req, size_t count); + + atl_status_t cancel(atl_mpi_ep_t& ep, atl_req_t* req); + + atl_status_t poll(atl_mpi_ep_t& ep); + + atl_status_t check(atl_mpi_ep_t& ep, atl_req_t* req); + + void comms_free(std::vector& eps); + + atl_status_t finalize(); int get_rank() { - return ctx->coord.global_idx; + return global_coord.global_idx; } int get_size() { - return ctx->coord.global_count; + return global_coord.global_count; } - bool is_inited() override { + bool is_inited() { return inited; } + static void set_env(const atl_attr_t& attr); + void coord_update(MPI_Comm base_comm, atl_proc_coord_t& coord); + atl_status_t ep_init(std::vector& eps); + atl_status_t comm_split(const std::vector& base_eps, + std::vector& eps, + size_t color); + + static atl_mpi_env_info_t get_env_info(const char* key); + static atl_mpi_comm_info_t get_comm_info(MPI_Comm comm, const char* key); + private: - atl_ctx_t* ctx = nullptr; + MPI_Datatype atl2mpi_dtype(atl_datatype_t dtype); + void init_req(atl_req_t* req); + inline atl_status_t ep_progress(atl_mpi_ep_t& ep, atl_mpi_req_t* req); + MPI_Op atl2mpi_op(atl_reduction_t rtype, MPI_Datatype dtype); + void check_comm_nic_idx(MPI_Comm comm, size_t expected_idx); + void check_comm_ep_idx(MPI_Comm comm, size_t expected_idx); + void check_comm_info(MPI_Comm comm, const char* key, const char* expected_value); + size_t get_ep_idx(size_t ep_idx); + +#ifdef ENABLE_DEBUG + void check_ep(atl_mpi_ep_t& ep); +#else +#define check_ep(ep) +#endif + bool is_finalized{ false }; bool inited{ false }; + static atl_mpi_global_data global_data; + atl_progress_mode_t progress_mode; + bool sync_coll; + size_t ep_count; + atl_proc_coord_t global_coord; }; - #endif // CCL_ENABLE_MPI diff --git a/src/atl/mpi/atl_mpi_comm.cpp b/src/atl/mpi/atl_mpi_comm.cpp new file mode 100644 index 
000000000..a7ed97ad8 --- /dev/null +++ b/src/atl/mpi/atl_mpi_comm.cpp @@ -0,0 +1,127 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#ifdef CCL_ENABLE_MPI + +#include "atl/mpi/atl_mpi_comm.hpp" +#include "exec/exec.hpp" + +std::atomic atl_mpi_comm::comm_count{ 0 }; +atl_mpi* atl_mpi_comm::transport{ nullptr }; + +atl_mpi_comm::~atl_mpi_comm() { + static std::mutex memory_mutex; + std::lock_guard lock(memory_mutex); + tag.reset(); + comm_count--; + if (comm_count.load() == 0) { + delete transport; + transport = nullptr; + } +} + +atl_mpi_comm::atl_mpi_comm() { + init_transport(true); +} + +atl_mpi_comm::atl_mpi_comm(std::shared_ptr k) : atl_mpi_comm() { + (void)k; +} + +atl_mpi_comm::atl_mpi_comm(int total_rank_count, + const std::vector& ranks, + std::shared_ptr k) + : atl_mpi_comm() { + (void)total_rank_count; + (void)ranks; + (void)k; +} + +atl_mpi_comm::atl_mpi_comm(std::vector& parent_eps, + int parent_rank, + int parent_size, + int color) { + this->parent_rank = parent_rank; + this->parent_size = parent_size; + + transport->comm_split(parent_eps, eps, color); + transport->coord_update(eps[0].mpi_comm, coord); + rank = coord.global_idx; + size = coord.global_count; + init_transport(false); + rank2rank_map.resize(size); + MPI_Allgather(&parent_rank, 1, MPI_INT, rank2rank_map.data(), 1, MPI_INT, eps[0].mpi_comm); +} + +void atl_mpi_comm::eps_update() { + for (auto& ep : eps) { + ep.coord = &coord; + } +} + +std::shared_ptr 
atl_mpi_comm::comm_split(int color) { + std::shared_ptr comm = + std::shared_ptr(new atl_mpi_comm(eps, parent_rank, parent_size, color)); + + return static_cast>(comm); +} + +void atl_mpi_comm::init_transport(bool is_new) { + LOG_DEBUG("init ATL, requested ep_count ", attr.in.ep_count); + if (is_new) { + static std::mutex memory_mutex; + { + std::lock_guard lock(memory_mutex); + if (!transport) { + transport = new atl_mpi(); + } + if (!transport->is_inited()) { + CCL_THROW_IF_NOT( + transport->init(nullptr, nullptr, &attr, nullptr, pmi) == ATL_STATUS_SUCCESS, + "failed to initialize ATL"); + + int mpi_rank; + MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank); + if (mpi_rank == 0) { + print_atl_attrs(); + } + } + } + + transport->ep_init(eps); + transport->coord_update(MPI_COMM_WORLD, coord); + parent_rank = rank = coord.global_idx; + parent_size = size = coord.global_count; + rank2rank_map.resize(size); + + for (int i = 0; i < size; i++) { + rank2rank_map[i] = i; + } + } + + threads_per_process = 1; + ranks_per_process = 1; + + eps_update(); + init_tag(); + + comm_count++; + + executor_update(); +} +std::vector atl_mpi_comm::get_rank2rank_map() { + return rank2rank_map; +} +#endif //CCL_ENABLE_MPI diff --git a/src/atl/mpi/atl_mpi_comm.hpp b/src/atl/mpi/atl_mpi_comm.hpp new file mode 100644 index 000000000..564f4b0b8 --- /dev/null +++ b/src/atl/mpi/atl_mpi_comm.hpp @@ -0,0 +1,251 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ +#pragma once + +#ifdef CCL_ENABLE_MPI + +#include + +#include "atl/atl_base_comm.hpp" +#include "atl/mpi/atl_mpi.hpp" + +class atl_mpi_comm : public atl_base_comm { +public: + ~atl_mpi_comm() override; + + atl_mpi_comm(); + atl_mpi_comm(std::shared_ptr k); + atl_mpi_comm(int total_rank_count, + const std::vector& ranks, + std::shared_ptr k); + + atl_status_t main_addr_reserve(char* main_addr) override { + return ATL_STATUS_UNSUPPORTED; + } + + atl_status_t finalize() override { + transport->comms_free(eps); + return ATL_STATUS_SUCCESS; + } + + atl_status_t update() override { + return ATL_STATUS_UNSUPPORTED; + } + + atl_status_t wait_notification() override { + return ATL_STATUS_UNSUPPORTED; + } + + atl_status_t set_resize_function(atl_resize_fn_t fn) override { + return ATL_STATUS_UNSUPPORTED; + } + + atl_status_t mr_reg(const void* buf, size_t len, atl_mr_t** mr) override { + return transport->mr_reg(buf, len, mr); + } + + atl_status_t mr_dereg(atl_mr_t* mr) override { + return transport->mr_dereg(mr); + } + + atl_status_t send(size_t ep_idx, + const void* buf, + size_t len, + int dst_proc_idx, + uint64_t tag, + atl_req_t* req) override { + return transport->send(eps[ep_idx], buf, len, dst_proc_idx, tag, req); + } + + atl_status_t recv(size_t ep_idx, + void* buf, + size_t len, + int src_proc_idx, + uint64_t tag, + atl_req_t* req) override { + return transport->recv(eps[ep_idx], buf, len, src_proc_idx, tag, req); + } + + atl_status_t probe(size_t ep_idx, + int src_proc_idx, + uint64_t tag, + int* found, + size_t* recv_len) override { + return transport->probe(eps[ep_idx], src_proc_idx, tag, found, recv_len); + } + + atl_status_t allgatherv(size_t ep_idx, + const void* send_buf, + size_t send_len, + void* recv_buf, + const int* recv_lens, + const int* offsets, + atl_req_t* req) override { + return transport->allgatherv( + eps[ep_idx], send_buf, send_len, recv_buf, recv_lens, offsets, req); + } + + atl_status_t allreduce(size_t ep_idx, + const void* send_buf, + 
void* recv_buf, + size_t len, + atl_datatype_t dtype, + atl_reduction_t op, + atl_req_t* req) override { + return transport->allreduce(eps[ep_idx], send_buf, recv_buf, len, dtype, op, req); + } + + atl_status_t alltoall(size_t ep_idx, + const void* send_buf, + void* recv_buf, + int len, + atl_req_t* req) override { + return transport->alltoall(eps[ep_idx], send_buf, recv_buf, len, req); + } + + atl_status_t alltoallv(size_t ep_idx, + const void* send_buf, + const int* send_lens, + const int* send_offsets, + void* recv_buf, + const int* recv_lens, + const int* recv_offsets, + atl_req_t* req) override { + return transport->alltoallv( + eps[ep_idx], send_buf, send_lens, send_offsets, recv_buf, recv_lens, recv_offsets, req); + } + + atl_status_t barrier(size_t ep_idx, atl_req_t* req) override { + return transport->barrier(eps[ep_idx], req); + } + + atl_status_t bcast(size_t ep_idx, void* buf, size_t len, int root, atl_req_t* req) override { + return transport->bcast(eps[ep_idx], buf, len, root, req); + } + + atl_status_t reduce(size_t ep_idx, + const void* send_buf, + void* recv_buf, + size_t len, + int root, + atl_datatype_t dtype, + atl_reduction_t op, + atl_req_t* req) override { + return transport->reduce(eps[ep_idx], send_buf, recv_buf, len, root, dtype, op, req); + } + + atl_status_t reduce_scatter(size_t ep_idx, + const void* send_buf, + void* recv_buf, + size_t recv_len, + atl_datatype_t dtype, + atl_reduction_t op, + atl_req_t* req) override { + return transport->reduce_scatter(eps[ep_idx], send_buf, recv_buf, recv_len, dtype, op, req); + } + + atl_status_t read(size_t ep_idx, + void* buf, + size_t len, + atl_mr_t* mr, + uint64_t addr, + uintptr_t remote_key, + int dst_proc_idx, + atl_req_t* req) override { + return transport->read(eps[ep_idx], buf, len, mr, addr, remote_key, dst_proc_idx, req); + } + + atl_status_t write(size_t ep_idx, + const void* buf, + size_t len, + atl_mr_t* mr, + uint64_t addr, + uintptr_t remote_key, + int dst_proc_idx, + atl_req_t* 
req) override { + return transport->write(eps[ep_idx], buf, len, mr, addr, remote_key, dst_proc_idx, req); + } + + atl_status_t wait(size_t ep_idx, atl_req_t* req) override { + return transport->wait(eps[ep_idx], req); + } + + atl_status_t wait_all(size_t ep_idx, atl_req_t* req, size_t count) override { + return transport->wait_all(eps[ep_idx], req, count); + } + + atl_status_t cancel(size_t ep_idx, atl_req_t* req) override { + return transport->cancel(eps[ep_idx], req); + } + + atl_status_t poll(size_t ep_idx) override { + return transport->poll(eps[ep_idx]); + } + + atl_status_t check(size_t ep_idx, atl_req_t* req) override { + return transport->check(eps[ep_idx], req); + } + + size_t get_threads_per_process() override { + return threads_per_process; + } + + size_t get_ranks_per_process() override { + return ranks_per_process; + } + + int get_rank() override { + return rank; + } + + int get_size() override { + return size; + } + + int get_r2r_color() override { + return coord.local_idx; + } + + int get_host_color() override { + return coord.hostname_hash; + } + + /* + * TODO: Temporary change. + * Need to define correct to unique id + */ + size_t get_id() override { + return 0; + } + + std::shared_ptr comm_split(int color) override; + + std::vector get_rank2rank_map() override; + +private: + atl_mpi_comm(std::vector& parent_eps, + int parent_rank, + int parent_size, + int color); + void eps_update(); + std::vector eps; + static atl_mpi* transport; + static std::atomic comm_count; + + void init_transport(bool is_new); +}; + +#endif //CCL_ENABLE_MPI diff --git a/src/atl/mpi/atl_mpi_global_data.cpp b/src/atl/mpi/atl_mpi_global_data.cpp new file mode 100644 index 000000000..d45da0028 --- /dev/null +++ b/src/atl/mpi/atl_mpi_global_data.cpp @@ -0,0 +1,700 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+#ifdef CCL_ENABLE_MPI
+
+#include "atl/mpi/atl_mpi.hpp"
+#include "atl/mpi/atl_mpi_global_data.hpp"
+#include "common/global/global.hpp"
+#include "common/log/log.hpp"
+
+// Validates the pointer arguments of a custom reduction callback.
+// Throws through CCL_THROW_IF_NOT when in_buf, inout_buf or length is null;
+// the message starts with caller_func_name and lists the three pointers.
+// datatype is accepted only so callers can pass their full argument list;
+// it is not inspected.
+void check_op_params(void* in_buf,
+                     void* inout_buf,
+                     int* length,
+                     MPI_Datatype* datatype,
+                     const char* caller_func_name) {
+    (void)datatype;
+    CCL_THROW_IF_NOT(in_buf && inout_buf && length,
+                     caller_func_name,
+                     " requested, bad arguments: ",
+                     in_buf,
+                     " ",
+                     inout_buf,
+                     " ",
+                     length);
+}
+
+#ifdef ATL_MPI_FP16
+
+// Shared body of the FP16 reduction callbacks: views both buffers as
+// unsigned short arrays and hands them, with the element count read from
+// *length, to ccl_fp16_reduce_impl for the requested reduction.
+// No argument checks happen here; the callbacks below run check_op_params
+// before calling this helper.
+void FP16_INLINE_TARGET_ATTRIBUTE_ALL fp16_base_op(void* in,
+                                                   void* inout,
+                                                   int* length,
+                                                   ccl::reduction op) {
+    unsigned short* in_buf = (unsigned short*)in;
+    unsigned short* inout_buf = (unsigned short*)inout;
+
+    size_t len = *length;
+    ccl_fp16_reduce_impl(in_buf, inout_buf, len, op);
+}
+
+// FP16 sum callback (passed to MPI_Op_create in fp16_init): validates the
+// arguments, then reduces `in` into `inout` with ccl::reduction::sum.
+void FP16_TARGET_ATTRIBUTE_ALL fp16_sum_op(void* in,
+                                           void* inout,
+                                           int* length,
+                                           MPI_Datatype* datatype) {
+    check_op_params(in, inout, length, datatype, __FUNCTION__);
+    fp16_base_op(in, inout, length, ccl::reduction::sum);
+}
+
+// FP16 product callback: same contract as fp16_sum_op, using
+// ccl::reduction::prod.
+void FP16_TARGET_ATTRIBUTE_ALL fp16_prod_op(void* in,
+                                            void* inout,
+                                            int* length,
+                                            MPI_Datatype* datatype) {
+    check_op_params(in, inout, length, datatype, __FUNCTION__);
+    fp16_base_op(in, inout, length, ccl::reduction::prod);
+}
+
+// FP16 minimum callback: same contract as fp16_sum_op, using
+// ccl::reduction::min.
+void FP16_TARGET_ATTRIBUTE_ALL fp16_min_op(void* in,
+                                           void* inout,
+                                           int* length,
+                                           MPI_Datatype* datatype) {
+    check_op_params(in, inout, length, datatype, __FUNCTION__);
+    fp16_base_op(in, inout, length, ccl::reduction::min);
+}
+
+// FP16 maximum callback: same contract as fp16_sum_op, using
+// ccl::reduction::max.
+void FP16_TARGET_ATTRIBUTE_ALL fp16_max_op(void* in,
+                                           void* inout,
+                                           int* length,
+                                           MPI_Datatype* datatype) {
+    check_op_params(in, inout, length, datatype, __FUNCTION__);
+    fp16_base_op(in, inout, length, ccl::reduction::max);
+}
+#endif // ATL_MPI_FP16
+
+#ifdef ATL_MPI_BF16
+
+// Shared body of the BF16 reduction callbacks: views both buffers as
+// unsigned short arrays and hands them, with the element count read from
+// *length, to ccl_bf16_reduce_impl for the requested reduction.
+// No argument checks happen here; the callbacks below run check_op_params
+// before calling this helper.
+void BF16_INLINE_TARGET_ATTRIBUTE_ALL bf16_base_op(void* in,
+                                                   void* inout,
+                                                   int* length,
+                                                   ccl::reduction op) {
+    unsigned short* in_buf = (unsigned short*)in;
+    unsigned short* inout_buf = (unsigned short*)inout;
+
+    size_t len = *length;
+    ccl_bf16_reduce_impl(in_buf, inout_buf, len, op);
+}
+
+// BF16 sum callback (passed to MPI_Op_create in bf16_init): validates the
+// arguments, then reduces `in` into `inout` with ccl::reduction::sum.
+void BF16_TARGET_ATTRIBUTE_ALL bf16_sum_op(void* in,
+                                           void* inout,
+                                           int* length,
+                                           MPI_Datatype* datatype) {
+    check_op_params(in, inout, length, datatype, __FUNCTION__);
+    bf16_base_op(in, inout, length, ccl::reduction::sum);
+}
+
+// BF16 product callback: same contract as bf16_sum_op, using
+// ccl::reduction::prod.
+void BF16_TARGET_ATTRIBUTE_ALL bf16_prod_op(void* in,
+                                            void* inout,
+                                            int* length,
+                                            MPI_Datatype* datatype) {
+    check_op_params(in, inout, length, datatype, __FUNCTION__);
+    bf16_base_op(in, inout, length, ccl::reduction::prod);
+}
+
+// BF16 minimum callback: same contract as bf16_sum_op, using
+// ccl::reduction::min.
+void BF16_TARGET_ATTRIBUTE_ALL bf16_min_op(void* in,
+                                           void* inout,
+                                           int* length,
+                                           MPI_Datatype* datatype) {
+    check_op_params(in, inout, length, datatype, __FUNCTION__);
+    bf16_base_op(in, inout, length, ccl::reduction::min);
+}
+
+// BF16 maximum callback: same contract as bf16_sum_op, using
+// ccl::reduction::max.
+void BF16_TARGET_ATTRIBUTE_ALL bf16_max_op(void* in,
+                                           void* inout,
+                                           int* length,
+                                           MPI_Datatype* datatype) {
+    check_op_params(in, inout, length, datatype, __FUNCTION__);
+    bf16_base_op(in, inout, length, ccl::reduction::max);
+}
+#endif // ATL_MPI_BF16
+
+// Writes a human-readable description of an MPI error code to std::cout in
+// the form "MPI error: <text> (<code>)".
+// MPI_Error_string reports the number of characters it produced, not
+// counting the terminating NUL, so the terminator belongs at index
+// result_len; writing it at result_len - 1 would drop the last character of
+// the message and index before the buffer when the length is 0.
+void atl_mpi_global_data::print_error(int error) {
+    char str_error[MPI_MAX_ERROR_STRING] = { 0 };
+    int result_len = 0;
+
+    if (MPI_Error_string(error, str_error, &result_len) != MPI_SUCCESS) {
+        /* no description available, report the numeric code only */
+        result_len = 0;
+    }
+
+    /* keep the terminator index inside the buffer whatever MPI returned */
+    if (result_len < 0) {
+        result_len = 0;
+    }
+    else if (result_len >= MPI_MAX_ERROR_STRING) {
+        result_len = MPI_MAX_ERROR_STRING - 1;
+    }
+    str_error[result_len] = '\0';
+
+    ccl_logger::format(std::cout, "MPI error: %s (%d)", str_error, error);
+}
+
+// Sets the Intel MPI specific environment (see the body for the exact knobs).
+atl_status_t atl_mpi_global_data::set_impi_env(const atl_attr_t& attr,
+                                               const atl_mpi_lib_attr_t& lib_attr) {
+    char
ep_count_str[MPI_MAX_INFO_VAL] = { 0 };
+    snprintf(ep_count_str, MPI_MAX_INFO_VAL, "%zu", get_ep_count(attr));
+
+    /* every setenv below passes overwrite=0, so values already present in
+       the environment are kept */
+    if (attr.in.ep_count)
+        setenv("I_MPI_OFI_ISEND_INJECT_THRESHOLD", "0", 0);
+
+#ifdef CCL_ENABLE_SYCL
+    setenv("I_MPI_SHM_CMA", "0", 0);
+    if (attr.in.enable_hmem && lib_attr.hmem) {
+        setenv("I_MPI_OFFLOAD", "2", 0);
+        setenv("I_MPI_OFFLOAD_TOPOLIB", "l0", 0);
+        setenv("I_MPI_OFFLOAD_QUEUE_CACHE", "1", 0);
+        setenv("I_MPI_OFFLOAD_LIST_CACHE", "1", 0);
+        setenv("I_MPI_OFFLOAD_MEMCPY_KIND", "blocked", 0);
+        if (attr.in.ep_count > 1) {
+            /* try to set global lock level before vci level
+               because setenv is invoked with overwrite=0 */
+            setenv("I_MPI_THREAD_LOCK_LEVEL", "global", 0);
+        }
+    }
+#endif // CCL_ENABLE_SYCL
+
+    setenv("I_MPI_THREAD_SPLIT", "1", 0);
+    setenv("I_MPI_THREAD_RUNTIME", "generic", 0);
+    setenv("I_MPI_THREAD_MAX", ep_count_str, 0);
+    setenv("I_MPI_THREAD_ID_KEY", EP_IDX_KEY, 0);
+    setenv("I_MPI_THREAD_LOCK_LEVEL", "vci", 0);
+
+    return ATL_STATUS_SUCCESS;
+}
+
+// Number of MPI endpoints to request: attr.in.ep_count, plus
+// attr.in.enable_extra_ep when that field is non-zero.
+size_t atl_mpi_global_data::get_ep_count(const atl_attr_t& attr) {
+    size_t mpi_ep_count = attr.in.ep_count;
+    if (attr.in.enable_extra_ep)
+        mpi_ep_count += attr.in.enable_extra_ep;
+    return mpi_ep_count;
+}
+
+// Identifies the loaded MPI library by matching the string returned by
+// MPI_Get_library_version against the entries of mpi_lib_infos.
+//
+// For each candidate entry (ATL_MPI_LIB_NONE entries are skipped, and when
+// the CCL_ATL_MPI environment variable is set only the entry with that name
+// is considered) the function:
+//   1. looks for version_prefix_1 (then version_prefix_2, if any) and parses
+//      the number that follows with atoi;
+//   2. rejects the entry when that number is below min_version_value;
+//   3. when the entry has a kind prefix/value, warns if the kind found in
+//      the version string differs (the entry is still accepted).
+// The first accepted entry wins.
+//
+// Returns the detected type and hmem flag, or { ATL_MPI_LIB_NONE, 0 } when
+// the version string cannot be read or nothing matches.
+atl_mpi_global_data::atl_mpi_lib_attr_t atl_mpi_global_data::get_lib_attr() {
+    atl_mpi_lib_attr_t lib_attr = { ATL_MPI_LIB_NONE, 0 };
+
+    char mpi_version[MPI_MAX_LIBRARY_VERSION_STRING] = { 0 };
+    int mpi_version_len = -1, i;
+    const atl_mpi_lib_info_t* final_info = NULL;
+
+    /* can be called before MPI_Init */
+    int ret = MPI_Get_library_version(mpi_version, &mpi_version_len);
+
+    if ((ret != MPI_SUCCESS) || (mpi_version_len < 0) ||
+        (mpi_version_len > MPI_MAX_LIBRARY_VERSION_STRING)) {
+        LOG_WARN("can not retrieve MPI version, mpi_version_len ", mpi_version_len, ", ret", ret);
+        return lib_attr;
+    }
+
+    /* remove trailing spaces at the end for more compact log;
+       the cast keeps isspace defined for bytes with the high bit set */
+    while (strlen(mpi_version) && isspace((unsigned char)mpi_version[strlen(mpi_version) - 1]))
+        mpi_version[strlen(mpi_version) - 1] = '\0';
+
+    LOG_DEBUG("MPI version: ", mpi_version);
+
+    /* for filtering */
+    char* lib_type_env = getenv("CCL_ATL_MPI");
+
+    for (i = 0; i < MPI_LIB_INFO_MAX_COUNT; i++) {
+        const atl_mpi_lib_info_t* info = &(mpi_lib_infos[i]);
+
+        if (info->type == ATL_MPI_LIB_NONE)
+            continue;
+
+        if (lib_type_env) {
+            if (strcmp(lib_type_env, info->name)) {
+                LOG_DEBUG("library ", info->name, " is filtered out by user input ", lib_type_env);
+                continue;
+            }
+            else {
+                LOG_DEBUG("use lib_type = ", lib_type_env, " because it is requested explicitly");
+            }
+        }
+
+        CCL_THROW_IF_NOT(info->version_prefix_1, "empty version_prefix_1");
+        CCL_THROW_IF_NOT(info->min_version_value >= 0, "unexpected minimal version");
+
+        const char* version_substr = NULL;
+        if ((version_substr = strstr(mpi_version, info->version_prefix_1))) {
+            version_substr += strlen(info->version_prefix_1);
+            LOG_DEBUG("version_substr: ", version_substr);
+
+            if (info->version_prefix_2) {
+                version_substr = strstr(version_substr, info->version_prefix_2);
+                if (!version_substr) {
+                    LOG_DEBUG("can't find version_prefix_2 ", info->version_prefix_2);
+                    continue;
+                }
+                version_substr += strlen(info->version_prefix_2);
+                LOG_DEBUG("version_substr: ", version_substr);
+            }
+
+            int version_value = (version_substr) ? atoi(version_substr) : -1;
+            LOG_DEBUG("MPI numerical version: ", version_value);
+
+            if (version_value < info->min_version_value) {
+                LOG_WARN("loaded MPI doesn't match with expected version, "
+                         "consider to switch to ",
+                         info->version_prefix_1,
+                         " ",
+                         (info->version_prefix_2 ? info->version_prefix_2 : ""),
+                         info->min_version_value,
+                         " (min) ",
+                         (info->kind_value ? info->kind_value : ""),
+                         "\n");
+                continue;
+            }
+
+            if (info->kind_prefix && info->kind_value) {
+                const char* kind_substr = mpi_version;
+
+                if ((kind_substr = strstr(kind_substr, info->kind_prefix))) {
+                    kind_substr += strlen(info->kind_prefix);
+                    while ((isspace((unsigned char)*kind_substr)) &&
+                           (kind_substr < (mpi_version + mpi_version_len)))
+                        kind_substr++;
+
+                    LOG_DEBUG("kind_substr: ", kind_substr);
+
+                    /* a kind mismatch is reported but does not reject the entry */
+                    if (strncmp(kind_substr, info->kind_value, strlen(info->kind_value))) {
+                        LOG_WARN("loaded MPI version (",
+                                 version_value,
+                                 ") ",
+                                 "is higher or equal to minimal expected version (",
+                                 info->min_version_value,
+                                 ") ",
+                                 "but kind (",
+                                 kind_substr,
+                                 ") doesn't match with expected kind (",
+                                 info->kind_value,
+                                 "), "
+                                 "consider to switch to ",
+                                 info->version_prefix_1,
+                                 " ",
+                                 (info->version_prefix_2 ? info->version_prefix_2 : ""),
+                                 info->min_version_value,
+                                 " (min version) ",
+                                 (info->kind_value ? info->kind_value : ""),
+                                 "\n");
+                    }
+                }
+                else {
+                    LOG_DEBUG("MPI version is high enough, but kind_prefix (",
+                              info->kind_prefix,
+                              ") can not be found",
+                              " treat this like expected kind (",
+                              info->kind_value,
+                              ") was found");
+                }
+            }
+
+            final_info = info;
+            LOG_DEBUG("set lib_type = ",
+                      info->name,
+                      " because "
+                      "version (",
+                      version_value,
+                      ") is higher or equal to minimal expected version (",
+                      info->min_version_value,
+                      ")");
+
+            lib_attr.type = final_info->type;
+            /* min_hmem_version_value is the minimal version with hmem
+               support, so hmem is available from that version upwards.
+               Entries carrying a negative value never enable it
+               (presumably -1 marks "no hmem-capable release" - confirm). */
+            lib_attr.hmem = ((final_info->min_hmem_version_value >= 0) &&
+                             (version_value >= final_info->min_hmem_version_value))
+                                ? 1
+                                : 0;
+
+            break;
+        }
+    }
+
+    if (final_info) {
+        LOG_DEBUG("MPI library type: ", final_info->name);
+    }
+    else {
+        LOG_DEBUG("MPI library type: none");
+    }
+
+    return lib_attr;
+}
+
+// Creates the MPI handles used for BF16 reductions: a datatype made of two
+// contiguous MPI_BYTEs and one custom op per reduction (sum, prod, min, max),
+// each created with commute = 1.
+// Returns ATL_STATUS_SUCCESS without creating anything when bf16_impl_type is
+// at or below ccl_bf16_no_hardware_support, or when ATL_MPI_BF16 is not
+// defined. On the first failing MPI call the error is logged and
+// ATL_STATUS_FAILURE is returned; handles created before the failure are not
+// released here.
+int atl_mpi_global_data::bf16_init() {
+    if (ccl::global_data::env().bf16_impl_type <= ccl_bf16_no_hardware_support) {
+        return ATL_STATUS_SUCCESS;
+    }
+
+#ifdef ATL_MPI_BF16
+
+    int ret = MPI_SUCCESS;
+    // create custom MPI BF16 dtype
+    ret = MPI_Type_contiguous(2, MPI_BYTE, &bf16.dtype);
+    if (ret != MPI_SUCCESS) {
+        LOG_ERROR("cannot create MPI BF16 dtype");
+        print_error(ret);
+        return ATL_STATUS_FAILURE;
+    }
+
+    ret = MPI_Type_commit(&bf16.dtype);
+    if (ret != MPI_SUCCESS) {
+        LOG_ERROR("cannot commit MPI BF16 type");
+        print_error(ret);
+        return ATL_STATUS_FAILURE;
+    }
+
+    // create custom MPI BF16 summation op
+    ret = MPI_Op_create(&bf16_sum_op, 1, &bf16.sum_op);
+    if (ret != MPI_SUCCESS) {
+        LOG_ERROR("cannot create MPI BF16 sum op");
+        print_error(ret);
+        return ATL_STATUS_FAILURE;
+    }
+
+    // create custom MPI BF16 product op
+    ret = MPI_Op_create(&bf16_prod_op, 1, &bf16.prod_op);
+    if (ret != MPI_SUCCESS) {
+        LOG_ERROR("cannot create MPI BF16 prod op");
+        print_error(ret);
+        return ATL_STATUS_FAILURE;
+    }
+
+    // create custom MPI BF16 min op
+    ret = MPI_Op_create(&bf16_min_op, 1, &bf16.min_op);
+    if (ret != MPI_SUCCESS) {
+        LOG_ERROR("cannot create MPI BF16 min op");
+        print_error(ret);
+        return ATL_STATUS_FAILURE;
+    }
+
+    // create custom MPI BF16 max op
+    ret = MPI_Op_create(&bf16_max_op, 1, &bf16.max_op);
+    if (ret != MPI_SUCCESS) {
+        LOG_ERROR("cannot create MPI BF16 max op");
+        print_error(ret);
+        return ATL_STATUS_FAILURE;
+    }
+
+#endif // ATL_MPI_BF16
+
+    return ATL_STATUS_SUCCESS;
+}
+
+// Frees every BF16 MPI handle that is not already the corresponding NULL
+// handle.
+void atl_mpi_global_data::bf16_finalize() {
+    if (bf16.dtype != MPI_DATATYPE_NULL) {
+        MPI_Type_free(&bf16.dtype);
+    }
+
+    if (bf16.sum_op != MPI_OP_NULL) {
+        MPI_Op_free(&bf16.sum_op);
+    }
+
+    if (bf16.prod_op != MPI_OP_NULL) {
+        MPI_Op_free(&bf16.prod_op);
+    }
+
+    if (bf16.min_op
!= MPI_OP_NULL) {
+        MPI_Op_free(&bf16.min_op);
+    }
+
+    if (bf16.max_op != MPI_OP_NULL) {
+        MPI_Op_free(&bf16.max_op);
+    }
+}
+
+// Creates the MPI handles used for FP16 reductions: a datatype made of two
+// contiguous MPI_BYTEs and one custom op per reduction (sum, prod, min, max),
+// each created with commute = 1.
+// Returns ATL_STATUS_SUCCESS without creating anything when fp16_impl_type is
+// at or below ccl_fp16_no_hardware_support, or when ATL_MPI_FP16 is not
+// defined. On the first failing MPI call the error is logged and
+// ATL_STATUS_FAILURE is returned; handles created before the failure are not
+// released here.
+int atl_mpi_global_data::fp16_init() {
+    if (ccl::global_data::env().fp16_impl_type <= ccl_fp16_no_hardware_support) {
+        return ATL_STATUS_SUCCESS;
+    }
+
+#ifdef ATL_MPI_FP16
+
+    int ret = MPI_SUCCESS;
+
+    // create custom MPI FP16 dtype
+    ret = MPI_Type_contiguous(2, MPI_BYTE, &fp16.dtype);
+    if (ret != MPI_SUCCESS) {
+        LOG_ERROR("cannot create MPI FP16 dtype");
+        print_error(ret);
+        return ATL_STATUS_FAILURE;
+    }
+
+    // the datatype must be committed before it can be used in communication
+    ret = MPI_Type_commit(&fp16.dtype);
+    if (ret != MPI_SUCCESS) {
+        LOG_ERROR("cannot commit MPI FP16 type");
+        print_error(ret);
+        return ATL_STATUS_FAILURE;
+    }
+
+    // create custom MPI FP16 summation op
+    ret = MPI_Op_create(&fp16_sum_op, 1, &fp16.sum_op);
+    if (ret != MPI_SUCCESS) {
+        LOG_ERROR("cannot create MPI FP16 sum op");
+        print_error(ret);
+        return ATL_STATUS_FAILURE;
+    }
+
+    // create custom MPI FP16 product op
+    ret = MPI_Op_create(&fp16_prod_op, 1, &fp16.prod_op);
+    if (ret != MPI_SUCCESS) {
+        LOG_ERROR("cannot create MPI FP16 prod op");
+        print_error(ret);
+        return ATL_STATUS_FAILURE;
+    }
+
+    // create custom MPI FP16 min op
+    ret = MPI_Op_create(&fp16_min_op, 1, &fp16.min_op);
+    if (ret != MPI_SUCCESS) {
+        LOG_ERROR("cannot create MPI FP16 min op");
+        print_error(ret);
+        return ATL_STATUS_FAILURE;
+    }
+
+    // create custom MPI FP16 max op
+    ret = MPI_Op_create(&fp16_max_op, 1, &fp16.max_op);
+    if (ret != MPI_SUCCESS) {
+        LOG_ERROR("cannot create MPI FP16 max op");
+        print_error(ret);
+        return ATL_STATUS_FAILURE;
+    }
+
+#endif // ATL_MPI_FP16
+
+    return ATL_STATUS_SUCCESS;
+}
+
+// Frees every FP16 MPI handle that is not already the corresponding NULL
+// handle.
+void atl_mpi_global_data::fp16_finalize() {
+    if (fp16.dtype != MPI_DATATYPE_NULL) {
+        MPI_Type_free(&fp16.dtype);
+    }
+
+    if (fp16.sum_op != MPI_OP_NULL) {
+        MPI_Op_free(&fp16.sum_op);
+    }
+
+    if (fp16.prod_op != MPI_OP_NULL) {
+        MPI_Op_free(&fp16.prod_op);
+    }
+
+    if (fp16.min_op != MPI_OP_NULL) {
+        MPI_Op_free(&fp16.min_op);
+    }
+
+    if (fp16.max_op != MPI_OP_NULL) {
+        MPI_Op_free(&fp16.max_op);
+    }
+}
+
+// Sanity-checks the Intel MPI environment.
+// Fails when I_MPI_THREAD_MAX is unset or does not equal get_ep_count(attr),
+// and when I_MPI_ROOT is unset (an error naming the expected library kind and
+// minimal version is logged in that case).
+atl_status_t atl_mpi_global_data::check_impi_env(const atl_attr_t& attr) {
+    char* ep_count_env = getenv("I_MPI_THREAD_MAX");
+    if (!ep_count_env)
+        return ATL_STATUS_FAILURE;
+    if (atoi(ep_count_env) != (int)(get_ep_count(attr)))
+        return ATL_STATUS_FAILURE;
+
+    if (!getenv("I_MPI_ROOT")) {
+        atl_mpi_lib_type_t type = ATL_MPI_LIB_IMPI;
+        LOG_ERROR("CCL/MPI uses ",
+                  mpi_lib_infos[type].version_prefix_1,
+                  " but I_MPI_ROOT is not set. ",
+                  "Please source ",
+                  mpi_lib_infos[type].kind_value,
+                  " version of ",
+                  mpi_lib_infos[type].version_prefix_1,
+                  " (",
+                  mpi_lib_infos[type].min_version_value,
+                  " or higher version).");
+        return ATL_STATUS_FAILURE;
+    }
+
+    return ATL_STATUS_SUCCESS;
+}
+
+// Refreshes the global MPI settings from attr and creates the BF16/FP16 MPI
+// handles.
+//  - detects the MPI library if that has not happened yet;
+//  - copies enable_extra_ep, mnic_type and mnic_offset from attr->in;
+//  - forces mnic_type to ATL_MNIC_NONE unless the library is MPICH;
+//  - derives mnic_count from the library (local/global key) or 1, then
+//    limits it to attr->in.mnic_count and to a minimum of 1.
+// Returns ATL_STATUS_FAILURE when either bf16_init or fp16_init fails; in
+// both cases every custom handle created so far is released before
+// returning.
+atl_status_t atl_mpi_global_data::update_global_data(atl_attr_t* attr) {
+    if (mpi_lib_attr.type == ATL_MPI_LIB_NONE)
+        mpi_lib_attr = get_lib_attr();
+
+    extra_ep = attr->in.enable_extra_ep;
+
+    mnic_type = attr->in.mnic_type;
+    if (mpi_lib_attr.type != ATL_MPI_LIB_MPICH) {
+        /* only MPICH supports multi-NIC */
+        mnic_type = ATL_MNIC_NONE;
+    }
+
+    if (mnic_type == ATL_MNIC_LOCAL) {
+        mnic_count = get_nic_count(LOCAL_NIC_COUNT_KEY);
+    }
+    else if (mnic_type == ATL_MNIC_GLOBAL) {
+        mnic_count = get_nic_count(GLOBAL_NIC_COUNT_KEY);
+    }
+    else if (mnic_type == ATL_MNIC_NONE) {
+        mnic_count = 1;
+    }
+    mnic_count = std::min(mnic_count, attr->in.mnic_count);
+    mnic_count = std::max(mnic_count, (size_t)(1));
+    mnic_offset = attr->in.mnic_offset;
+
+    if (bf16_init() == ATL_STATUS_FAILURE) {
+        bf16_finalize();
+        return ATL_STATUS_FAILURE;
+    }
+
+    if (fp16_init() == ATL_STATUS_FAILURE) {
+        fp16_finalize();
+        /* bf16_init succeeded just above; release its handles as well so a
+           failed update does not leave them allocated. bf16_finalize skips
+           handles that are already NULL, so this is safe to repeat. */
+        bf16_finalize();
+        return ATL_STATUS_FAILURE;
+    }
+    return ATL_STATUS_SUCCESS;
+}
+
+// Sets the MPICH specific environment; every setenv passes overwrite=0, so
+// values already present in the environment are kept.
+atl_status_t atl_mpi_global_data::set_mpich_env(const atl_attr_t& attr) {
+    char ep_count_str[MPI_MAX_INFO_VAL] = { 0 };
+    snprintf(ep_count_str, MPI_MAX_INFO_VAL, "%zu", get_ep_count(attr));
+
+    setenv("MPIR_CVAR_CH4_MT_MODEL", "direct", 0);
+
setenv("MPIR_CVAR_CH4_NUM_VCIS", ep_count_str, 0);
+    setenv("MPIR_CVAR_CH4_OFI_MAX_VCIS", ep_count_str, 0);
+    setenv("MPIR_COMM_HINT_VCI", EP_IDX_KEY, 0);
+
+    // extra MPICH diagnostics are requested only at debug log level or above
+    auto& env = ccl::global_data::env();
+    if (env.log_level >= ccl_log_level::debug) {
+        setenv("MPIR_CVAR_CH4_RUNTIME_CONF_DEBUG", "1", 0);
+        setenv("MPIR_CVAR_CH4_OFI_CAPABILITY_SETS_DEBUG", "1", 0);
+        setenv("MPIR_CVAR_DEBUG_SUMMARY", "1", 0);
+    }
+
+    setenv("FI_PSM2_DELAY", "0", 0);
+    setenv("FI_PSM2_TIMEOUT", "0", 0);
+    setenv("FI_PSM2_NAME_SERVER", "0", 0);
+    setenv("HFI_NO_CPUAFFINITY", "1", 0);
+
+    return ATL_STATUS_SUCCESS;
+}
+
+/* set these knobs without detection of MPI library type */
+// All variables are set with overwrite=0, so values already present in the
+// environment are kept. attr is not used. Always returns ATL_STATUS_SUCCESS.
+atl_status_t atl_mpi_global_data::set_base_env(const atl_attr_t& attr) {
+    setenv("PSM2_MULTI_EP", "1", 0);
+    setenv("FI_OFI_RXM_USE_HASH", "0", 0);
+
+#ifdef CCL_ENABLE_SYCL
+    setenv("FI_SHM_DISABLE_CMA", "1", 0);
+#endif // CCL_ENABLE_SYCL
+
+    setenv("MPIR_CVAR_DEFAULT_THREAD_LEVEL", "MPI_THREAD_MULTIPLE", 0);
+
+    /* request IMPI level append library kind into MPI_Get_library_version output */
+    setenv("I_MPI_INFO_LIBRARY_KIND", "1", 0);
+
+    return ATL_STATUS_SUCCESS;
+}
+
+// Prepares the process environment for the MPI transport.
+// When the library type is already known, only the matching check_*_env
+// sanity check runs and its status is returned.
+// Otherwise the base knobs are set, the library is detected, and the
+// library-specific knobs are set and checked.
+// NOTE(review): on this first-time path the results of set_*_env and
+// check_*_env are discarded - confirm that is intentional.
+atl_status_t atl_mpi_global_data::set_env(const atl_attr_t& attr) {
+    if (mpi_lib_attr.type != ATL_MPI_LIB_NONE) {
+        /* library type was already detected and env was set, make sanity check */
+        if (mpi_lib_attr.type == ATL_MPI_LIB_IMPI) {
+            return check_impi_env(attr);
+        }
+        else if (mpi_lib_attr.type == ATL_MPI_LIB_MPICH) {
+            return check_mpich_env(attr);
+        }
+        return ATL_STATUS_SUCCESS;
+    }
+
+    // must precede get_lib_attr: set_base_env requests that the library
+    // kind be appended to the MPI_Get_library_version output
+    set_base_env(attr);
+
+    mpi_lib_attr = get_lib_attr();
+
+    if (mpi_lib_attr.type == ATL_MPI_LIB_NONE) {
+        return ATL_STATUS_SUCCESS;
+    }
+
+    if (mpi_lib_attr.type == ATL_MPI_LIB_IMPI) {
+        set_impi_env(attr, mpi_lib_attr);
+        check_impi_env(attr);
+    }
+    else if (mpi_lib_attr.type == ATL_MPI_LIB_MPICH) {
+        set_mpich_env(attr);
+        check_mpich_env(attr);
+    }
+
+    int is_mpi_inited = 0;
+    MPI_Initialized(&is_mpi_inited);
+    if (is_mpi_inited) {
+        LOG_WARN("MPI
was initialized externally, CCL-MPI specific environment is ignored");
+    }
+    else {
+        LOG_DEBUG("set CCL-MPI specific environment");
+    }
+
+    return ATL_STATUS_SUCCESS;
+}
+
+// Sanity-checks the MPICH environment: fails when
+// MPIR_CVAR_CH4_OFI_MAX_VCIS is unset or does not equal get_ep_count(attr).
+atl_status_t atl_mpi_global_data::check_mpich_env(const atl_attr_t& attr) {
+    char* ep_count_env = getenv("MPIR_CVAR_CH4_OFI_MAX_VCIS");
+    if (!ep_count_env)
+        return ATL_STATUS_FAILURE;
+    if (atoi(ep_count_env) != (int)(get_ep_count(attr)))
+        return ATL_STATUS_FAILURE;
+    return ATL_STATUS_SUCCESS;
+}
+
+#ifdef ATL_MPI_BF16
+// Maps an ATL reduction type to the matching custom BF16 MPI op.
+// An unknown reduction type prints a message and terminates the process
+// with exit(1).
+MPI_Op atl_mpi_global_data::atl2mpi_op_bf16(atl_reduction_t rtype) {
+    switch (rtype) {
+        case ATL_REDUCTION_SUM: return bf16.sum_op;
+        case ATL_REDUCTION_PROD: return bf16.prod_op;
+        case ATL_REDUCTION_MIN: return bf16.min_op;
+        case ATL_REDUCTION_MAX: return bf16.max_op;
+        default: printf("unknown reduction type: %d\n", rtype); exit(1);
+    }
+}
+#endif // ATL_MPI_BF16
+
+#ifdef ATL_MPI_FP16
+// Maps an ATL reduction type to the matching custom FP16 MPI op.
+// An unknown reduction type prints a message and terminates the process
+// with exit(1).
+MPI_Op atl_mpi_global_data::atl2mpi_op_fp16(atl_reduction_t rtype) {
+    switch (rtype) {
+        case ATL_REDUCTION_SUM: return fp16.sum_op;
+        case ATL_REDUCTION_PROD: return fp16.prod_op;
+        case ATL_REDUCTION_MIN: return fp16.min_op;
+        case ATL_REDUCTION_MAX: return fp16.max_op;
+        default: printf("unknown reduction type: %d\n", rtype); exit(1);
+    }
+}
+#endif // ATL_MPI_FP16
+
+// Logs the global MPI transport settings at info level, but only while
+// ctx_count equals 1. The multi-NIC count and offset are printed only when
+// mnic_type is not ATL_MNIC_NONE.
+void atl_mpi_global_data::print_log_info() {
+    if (ctx_count == 1) {
+        // NOTE(review): no trailing semicolon on the next call, unlike its
+        // neighbours - this compiles only if LOG_INFO expands to a complete
+        // statement; confirm against the macro definition.
+        LOG_INFO("atl-mpi-global:")
+        LOG_INFO(" is_external_init: ", is_external_init);
+        LOG_INFO(" mpi_lib_attr.type: ", mpi_lib_infos[mpi_lib_attr.type].name);
+        LOG_INFO(" mpi_lib_attr.hmem: ", mpi_lib_attr.hmem);
+        LOG_INFO(" extra_ep: ", extra_ep);
+        LOG_INFO(" mnic_type: ", to_string(mnic_type));
+        if (mnic_type != ATL_MNIC_NONE) {
+            LOG_INFO(" mnic_count: ", mnic_count);
+            LOG_INFO(" mnic_offset: ", to_string(mnic_offset));
+        }
+    }
+}
+
+// Reads the NIC count stored under nic_count_key, looked up through
+// atl_mpi::get_env_info. Throws via CCL_THROW_IF_NOT when the key is not
+// found. The result starts from a default of 1.
+size_t atl_mpi_global_data::get_nic_count(const char* nic_count_key) {
+    size_t count = 1;
+    atl_mpi_env_info_t info = atl_mpi::get_env_info(nic_count_key);
+    CCL_THROW_IF_NOT(info.found, "MPI env key ",
nic_count_key, " was not set"); + + count = atoi(info.value); + if (count <= 0) { + count = 1; + } + + return count; +} + +#endif // CCL_ENABLE_MPI diff --git a/src/atl/mpi/atl_mpi_global_data.hpp b/src/atl/mpi/atl_mpi_global_data.hpp new file mode 100644 index 000000000..e0affbe5c --- /dev/null +++ b/src/atl/mpi/atl_mpi_global_data.hpp @@ -0,0 +1,161 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#ifdef CCL_ENABLE_MPI + +#include + +#include "atl/atl_def.h" +#include "comp/bf16/bf16_intrisics.hpp" +#include "comp/fp16/fp16_intrisics.hpp" + +#ifdef CCL_BF16_COMPILER +#define ATL_MPI_BF16 +#endif // CCL_BF16_COMPILER + +#ifdef CCL_FP16_COMPILER +#define ATL_MPI_FP16 +#endif // CCL_FP16_COMPILER + +class atl_mpi_global_data { +public: + typedef enum { ATL_MPI_LIB_IMPI, ATL_MPI_LIB_MPICH, ATL_MPI_LIB_NONE } atl_mpi_lib_type_t; + +private: + typedef struct { + atl_mpi_lib_type_t type; + int hmem; + } atl_mpi_lib_attr_t; + typedef struct { + // custom MPI operations for BF16 + MPI_Op sum_op; + MPI_Op prod_op; + MPI_Op min_op; + MPI_Op max_op; + // custom MPI dtype for BF16 + MPI_Datatype dtype; + } atl_mpi_bf16_data_t; + + typedef struct { + // custom MPI operations for FP16 + MPI_Op sum_op; + MPI_Op prod_op; + MPI_Op min_op; + MPI_Op max_op; + // custom MPI dtype for FP16 + MPI_Datatype dtype; + } atl_mpi_fp16_data_t; + + typedef struct { + atl_mpi_lib_type_t type; + const char* name; + + /* string prefix before 
numerical version of library, mandatory */ + const char* version_prefix_1; + + /* string prefix before numerical version of library, following prefix_1, optional */ + const char* version_prefix_2; + + /* minimal expected version of library, mandatory */ + int min_version_value; + + /* minimal expected version of library with hmem support, mandatory */ + int min_hmem_version_value; + + /* string prefix before library kind, optional */ + const char* kind_prefix; + + /* library kind, optional */ + const char* kind_value; + } atl_mpi_lib_info_t; + +#define MPI_LIB_INFO_MAX_COUNT 3 + + const atl_mpi_lib_info_t mpi_lib_infos[MPI_LIB_INFO_MAX_COUNT] = { + { ATL_MPI_LIB_IMPI, + "impi", + "Intel(R) MPI Library", + NULL, + 2019, + 2021, + "library kind:", + "release" }, + { ATL_MPI_LIB_MPICH, "mpich", "MPICH Custom Information:", "drop", 34, -1, NULL, NULL }, + { ATL_MPI_LIB_NONE, "none", "", NULL, 0, -1, NULL, NULL }, + }; + + size_t get_nic_count(const char* nic_count_key); + +public: + const char* EP_IDX_KEY = "vci"; + + const char* NIC_IDX_KEY = "multi_nic_pref_nic"; + const char* GLOBAL_NIC_COUNT_KEY = "num_nics"; + const char* LOCAL_NIC_COUNT_KEY = "num_close_nics"; + + int is_external_init; + size_t ctx_count; + int extra_ep; + atl_mnic_t mnic_type; + size_t mnic_count; + atl_mnic_offset_t mnic_offset; + atl_mpi_lib_attr_t mpi_lib_attr; + atl_mpi_bf16_data_t bf16; + atl_mpi_fp16_data_t fp16; + + atl_mpi_global_data() + : is_external_init(0), + ctx_count(0), + extra_ep(0), + mnic_type(ATL_MNIC_NONE), + mnic_count(1), + mnic_offset(ATL_MNIC_OFFSET_NONE) { + mpi_lib_attr.type = ATL_MPI_LIB_NONE; + mpi_lib_attr.hmem = 0; + + bf16.dtype = MPI_DATATYPE_NULL; + bf16.sum_op = MPI_OP_NULL; + bf16.prod_op = MPI_OP_NULL; + bf16.min_op = MPI_OP_NULL; + bf16.max_op = MPI_OP_NULL; + + fp16.dtype = MPI_DATATYPE_NULL; + fp16.sum_op = MPI_OP_NULL; + fp16.prod_op = MPI_OP_NULL; + fp16.min_op = MPI_OP_NULL; + fp16.max_op = MPI_OP_NULL; + } + atl_mpi_lib_attr_t get_lib_attr(); + size_t 
get_ep_count(const atl_attr_t& attr); + atl_status_t set_impi_env(const atl_attr_t& attr, const atl_mpi_lib_attr_t& lib_attr); + int bf16_init(); + void bf16_finalize(); + int fp16_init(); + void fp16_finalize(); + void print_error(int error); + atl_status_t check_impi_env(const atl_attr_t& attr); + atl_status_t update_global_data(atl_attr_t* attr); + atl_status_t set_mpich_env(const atl_attr_t& attr); + atl_status_t set_base_env(const atl_attr_t& attr); + atl_status_t check_mpich_env(const atl_attr_t& attr); + atl_status_t set_env(const atl_attr_t& attr); + MPI_Op atl2mpi_op_fp16(atl_reduction_t rtype); + MPI_Op atl2mpi_op_bf16(atl_reduction_t rtype); + void print_log_info(); +}; + +#endif diff --git a/src/atl/mpi/atl_mpi_impl.cpp b/src/atl/mpi/atl_mpi_impl.cpp deleted file mode 100644 index 95636ac82..000000000 --- a/src/atl/mpi/atl_mpi_impl.cpp +++ /dev/null @@ -1,1711 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-*/ -#ifdef CCL_ENABLE_MPI - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "atl.h" -#include "common/global/global.hpp" -#include "comp/bf16/bf16_intrisics.hpp" -#include "comp/bf16/bf16_utils.hpp" -#include "comp/fp16/fp16_intrisics.hpp" -#include "comp/fp16/fp16_utils.hpp" - -#define ATL_MPI_PM_KEY "atl-mpi" - -#define EP_IDX_KEY "ep_idx" - -#define GLOBAL_NIC_IDX_KEY "pref_nic" -#define GLOBAL_NIC_COUNT_KEY "num_nics" -#define LOCAL_NIC_IDX_KEY "pref_close_nic" -#define LOCAL_NIC_COUNT_KEY "num_close_nics" - -#define RET2ATL(ret) (ret != MPI_SUCCESS) ? ATL_STATUS_FAILURE : ATL_STATUS_SUCCESS - -typedef enum { ATL_MPI_LIB_IMPI, ATL_MPI_LIB_MPICH, ATL_MPI_LIB_NONE } atl_mpi_lib_type_t; - -typedef struct { - atl_mpi_lib_type_t type; - int hmem; -} atl_mpi_lib_attr_t; - -typedef struct { - atl_mpi_lib_type_t type; - const char* name; - - /* string prefix before numerical version of library, mandatory */ - const char* version_prefix_1; - - /* string prefix before numerical version of library, following prefix_1, optional */ - const char* version_prefix_2; - - /* minimal expected version of library, mandatory */ - int min_version_value; - - /* minimal expected version of library with hmem support, mandatory */ - int min_hmem_version_value; - - /* string prefix before library kind, optional */ - const char* kind_prefix; - - /* library kind, optional */ - const char* kind_value; -} atl_mpi_lib_info_t; - -#define MPI_LIB_INFO_MAX_COUNT 3 - -static atl_mpi_lib_info_t mpi_lib_infos[MPI_LIB_INFO_MAX_COUNT] = { - { ATL_MPI_LIB_IMPI, - "impi", - "Intel(R) MPI Library", - NULL, - 2019, - 2021, - "library kind:", - "release_mt" }, - { ATL_MPI_LIB_MPICH, "mpich", "MPICH Custom Information:", "drop", 34, -1, NULL, NULL }, - { ATL_MPI_LIB_NONE, "none", "", NULL, 0, -1, NULL, NULL }, -}; - -#ifdef CCL_BF16_COMPILER -#define ATL_MPI_BF16 -#endif // CCL_BF16_COMPILER - -#ifdef CCL_FP16_COMPILER -#define ATL_MPI_FP16 
-#endif // CCL_FP16_COMPILER - -typedef struct { - // custom MPI operations for BF16 - MPI_Op sum_op; - MPI_Op prod_op; - MPI_Op min_op; - MPI_Op max_op; - // custom MPI dtype for BF16 - MPI_Datatype dtype; -} atl_mpi_bf16_data_t; - -typedef struct { - // custom MPI operations for FP16 - MPI_Op sum_op; - MPI_Op prod_op; - MPI_Op min_op; - MPI_Op max_op; - // custom MPI dtype for FP16 - MPI_Datatype dtype; -} atl_mpi_fp16_data_t; - -typedef struct atl_mpi_global_data { - int is_external_init; - size_t ctx_count; - int extra_ep; - atl_mnic_t mnic_type; - size_t mnic_count; - atl_mpi_lib_attr_t mpi_lib_attr; - atl_mpi_bf16_data_t bf16; - atl_mpi_fp16_data_t fp16; - - atl_mpi_global_data() - : is_external_init(0), - ctx_count(0), - extra_ep(0), - mnic_type(ATL_MNIC_NONE), - mnic_count(1) { - mpi_lib_attr.type = ATL_MPI_LIB_NONE; - mpi_lib_attr.hmem = 0; - - bf16.dtype = MPI_DATATYPE_NULL; - bf16.sum_op = MPI_OP_NULL; - bf16.prod_op = MPI_OP_NULL; - bf16.min_op = MPI_OP_NULL; - bf16.max_op = MPI_OP_NULL; - - fp16.dtype = MPI_DATATYPE_NULL; - fp16.sum_op = MPI_OP_NULL; - fp16.prod_op = MPI_OP_NULL; - fp16.min_op = MPI_OP_NULL; - fp16.max_op = MPI_OP_NULL; - } - -} atl_mpi_global_data_t; - -static atl_mpi_global_data_t global_data; - -typedef enum { ATL_MPI_COMP_POSTED, ATL_MPI_COMP_COMPLETED } atl_mpi_comp_state_t; - -typedef struct { - MPI_Request native_req; - atl_mpi_comp_state_t comp_state; -} atl_mpi_req_t; - -typedef struct { - atl_ctx_t ctx; - int sync_coll; - atl_progress_mode_t progress_mode; -} atl_mpi_ctx_t; - -typedef struct { - atl_ep_t ep; - MPI_Comm mpi_comm; - - /* dummy recv operation to ensure progress in atl_poll */ - atl_mpi_req_t dummy_req; - MPI_Comm dummy_comm; -} atl_mpi_ep_t; - -typedef struct atl_mpi_comm_info { - int found; - MPI_Comm comm; - char key[MPI_MAX_INFO_KEY]; - char value[MPI_MAX_INFO_VAL]; - - atl_mpi_comm_info() { - found = 0; - comm = MPI_COMM_WORLD; - memset(key, 0, MPI_MAX_INFO_KEY); - memset(value, 0, MPI_MAX_INFO_VAL); - } -} 
atl_mpi_comm_info_t; - -#define MPI_BFLOAT16 \ - ({ \ - CCL_THROW_IF_NOT(global_data.bf16.dtype != MPI_DATATYPE_NULL, \ - "unsupported datatype: ATL_DTYPE_BF16"); \ - global_data.bf16.dtype; \ - }) - -#define MPI_FLOAT16 \ - ({ \ - CCL_THROW_IF_NOT(global_data.fp16.dtype != MPI_DATATYPE_NULL, \ - "unsupported datatype: ATL_DTYPE_FP16"); \ - global_data.fp16.dtype; \ - }) - -// helpers: check contract -static inline void atl_mpi_check_op_params(void* in_buf, - void* inout_buf, - int* length, - MPI_Datatype* datatype, - const char* caller_func_name) { - (void)datatype; - CCL_THROW_IF_NOT(in_buf && inout_buf && length, - caller_func_name, - " requested, bad arguments: ", - in_buf, - " ", - inout_buf, - " ", - length); -} - -static void atl_mpi_print_error(int error) __attribute__((unused)); -static void atl_mpi_print_error(int error) { - char str_error[MPI_MAX_ERROR_STRING]; - int result_len = MPI_MAX_ERROR_STRING; - - MPI_Error_string(error, str_error, &result_len); - - if (result_len > MPI_MAX_ERROR_STRING) { - result_len = MPI_MAX_ERROR_STRING; - } - str_error[result_len - 1] = '\0'; - - ccl_logger::format(std::cout, "MPI error: %s (%d)", str_error, error); -} - -#ifdef ATL_MPI_BF16 - -static void BF16_INLINE_TARGET_ATTRIBUTE_ALL atl_mpi_bf16_base_op(void* in, - void* inout, - int* length, - ccl::reduction op) { - unsigned short* in_buf = (unsigned short*)in; - unsigned short* inout_buf = (unsigned short*)inout; - - size_t len = *length; - ccl_bf16_reduce_impl(in_buf, inout_buf, len, op); -} - -static void BF16_TARGET_ATTRIBUTE_ALL atl_mpi_bf16_sum_op(void* in, - void* inout, - int* length, - MPI_Datatype* datatype) { - atl_mpi_check_op_params(in, inout, length, datatype, __FUNCTION__); - atl_mpi_bf16_base_op(in, inout, length, ccl::reduction::sum); -} - -static void BF16_TARGET_ATTRIBUTE_ALL atl_mpi_bf16_prod_op(void* in, - void* inout, - int* length, - MPI_Datatype* datatype) { - atl_mpi_check_op_params(in, inout, length, datatype, __FUNCTION__); - 
atl_mpi_bf16_base_op(in, inout, length, ccl::reduction::prod); -} - -static void BF16_TARGET_ATTRIBUTE_ALL atl_mpi_bf16_min_op(void* in, - void* inout, - int* length, - MPI_Datatype* datatype) { - atl_mpi_check_op_params(in, inout, length, datatype, __FUNCTION__); - atl_mpi_bf16_base_op(in, inout, length, ccl::reduction::min); -} - -static void BF16_TARGET_ATTRIBUTE_ALL atl_mpi_bf16_max_op(void* in, - void* inout, - int* length, - MPI_Datatype* datatype) { - atl_mpi_check_op_params(in, inout, length, datatype, __FUNCTION__); - atl_mpi_bf16_base_op(in, inout, length, ccl::reduction::max); -} -#endif // ATL_MPI_BF16 - -#ifdef ATL_MPI_FP16 - -static void FP16_INLINE_TARGET_ATTRIBUTE_ALL atl_mpi_fp16_base_op(void* in, - void* inout, - int* length, - ccl::reduction op) { - unsigned short* in_buf = (unsigned short*)in; - unsigned short* inout_buf = (unsigned short*)inout; - - size_t len = *length; - ccl_fp16_reduce_impl(in_buf, inout_buf, len, op); -} - -static void FP16_TARGET_ATTRIBUTE_ALL atl_mpi_fp16_sum_op(void* in, - void* inout, - int* length, - MPI_Datatype* datatype) { - atl_mpi_check_op_params(in, inout, length, datatype, __FUNCTION__); - atl_mpi_fp16_base_op(in, inout, length, ccl::reduction::sum); -} - -static void FP16_TARGET_ATTRIBUTE_ALL atl_mpi_fp16_prod_op(void* in, - void* inout, - int* length, - MPI_Datatype* datatype) { - atl_mpi_check_op_params(in, inout, length, datatype, __FUNCTION__); - atl_mpi_fp16_base_op(in, inout, length, ccl::reduction::prod); -} - -static void FP16_TARGET_ATTRIBUTE_ALL atl_mpi_fp16_min_op(void* in, - void* inout, - int* length, - MPI_Datatype* datatype) { - atl_mpi_check_op_params(in, inout, length, datatype, __FUNCTION__); - atl_mpi_fp16_base_op(in, inout, length, ccl::reduction::min); -} - -static void FP16_TARGET_ATTRIBUTE_ALL atl_mpi_fp16_max_op(void* in, - void* inout, - int* length, - MPI_Datatype* datatype) { - atl_mpi_check_op_params(in, inout, length, datatype, __FUNCTION__); - atl_mpi_fp16_base_op(in, inout, 
length, ccl::reduction::max); -} -#endif // ATL_MPI_FP16 - -static int atl_mpi_bf16_init() { - int ret = MPI_SUCCESS; - - if (ccl::global_data::env().bf16_impl_type <= ccl_bf16_no_hardware_support) { - return RET2ATL(ret); - } - -#ifdef ATL_MPI_BF16 - - // create custom MPI BF16 dtype - ret = MPI_Type_contiguous(2, MPI_BYTE, &global_data.bf16.dtype); - if (ret != MPI_SUCCESS) { - LOG_ERROR("cannot create MPI BF16 dtype"); - atl_mpi_print_error(ret); - return RET2ATL(ret); - } - - ret = MPI_Type_commit(&global_data.bf16.dtype); - if (ret != MPI_SUCCESS) { - LOG_ERROR("cannot commit MPI BF16 type"); - atl_mpi_print_error(ret); - return RET2ATL(ret); - } - - // create custom MPI BF16 summation op - ret = MPI_Op_create(&atl_mpi_bf16_sum_op, 1, &global_data.bf16.sum_op); - if (ret != MPI_SUCCESS) { - LOG_ERROR("cannot create MPI BF16 sum op"); - atl_mpi_print_error(ret); - return RET2ATL(ret); - } - - // create custom MPI BF16 production op - ret = MPI_Op_create(&atl_mpi_bf16_prod_op, 1, &global_data.bf16.prod_op); - if (ret != MPI_SUCCESS) { - LOG_ERROR("cannot create MPI BF16 prod op"); - atl_mpi_print_error(ret); - return RET2ATL(ret); - } - - // create custom MPI BF16 min op - ret = MPI_Op_create(&atl_mpi_bf16_min_op, 1, &global_data.bf16.min_op); - if (ret != MPI_SUCCESS) { - LOG_ERROR("cannot create MPI BF16 min op"); - atl_mpi_print_error(ret); - return RET2ATL(ret); - } - - // create custom MPI BF16 max op - ret = MPI_Op_create(&atl_mpi_bf16_max_op, 1, &global_data.bf16.max_op); - if (ret != MPI_SUCCESS) { - LOG_ERROR("cannot create MPI BF16 max op"); - atl_mpi_print_error(ret); - return RET2ATL(ret); - } - -#endif // ATL_MPI_BF16 - - return RET2ATL(ret); -} - -static void atl_mpi_bf16_finalize() { - if (global_data.bf16.dtype != MPI_DATATYPE_NULL) { - MPI_Type_free(&global_data.bf16.dtype); - } - - if (global_data.bf16.sum_op != MPI_OP_NULL) { - MPI_Op_free(&global_data.bf16.sum_op); - } - - if (global_data.bf16.prod_op != MPI_OP_NULL) { - 
MPI_Op_free(&global_data.bf16.prod_op); - } - - if (global_data.bf16.min_op != MPI_OP_NULL) { - MPI_Op_free(&global_data.bf16.min_op); - } - - if (global_data.bf16.max_op != MPI_OP_NULL) { - MPI_Op_free(&global_data.bf16.max_op); - } -} - -static int atl_mpi_fp16_init() { - int ret = MPI_SUCCESS; - - if (ccl::global_data::env().fp16_impl_type <= ccl_fp16_no_hardware_support) { - return RET2ATL(ret); - } - -#ifdef ATL_MPI_FP16 - - // create custom MPI FP16 dtype - ret = MPI_Type_contiguous(2, MPI_BYTE, &global_data.fp16.dtype); - if (ret != MPI_SUCCESS) { - LOG_ERROR("cannot create MPI FP16 dtype"); - atl_mpi_print_error(ret); - return RET2ATL(ret); - } - - ret = MPI_Type_commit(&global_data.fp16.dtype); - if (ret != MPI_SUCCESS) { - LOG_ERROR("cannot commit MPI FP16 type"); - atl_mpi_print_error(ret); - return RET2ATL(ret); - } - - // create custom MPI FP16 summation op - ret = MPI_Op_create(&atl_mpi_fp16_sum_op, 1, &global_data.fp16.sum_op); - if (ret != MPI_SUCCESS) { - LOG_ERROR("cannot create MPI FP16 sum op"); - atl_mpi_print_error(ret); - return RET2ATL(ret); - } - - // create custom MPI FP16 production op - ret = MPI_Op_create(&atl_mpi_fp16_prod_op, 1, &global_data.fp16.prod_op); - if (ret != MPI_SUCCESS) { - LOG_ERROR("cannot create MPI FP16 prod op"); - atl_mpi_print_error(ret); - return RET2ATL(ret); - } - - // create custom MPI FP16 min op - ret = MPI_Op_create(&atl_mpi_fp16_min_op, 1, &global_data.fp16.min_op); - if (ret != MPI_SUCCESS) { - LOG_ERROR("cannot create MPI FP16 min op"); - atl_mpi_print_error(ret); - return RET2ATL(ret); - } - - // create custom MPI FP16 max op - ret = MPI_Op_create(&atl_mpi_fp16_max_op, 1, &global_data.fp16.max_op); - if (ret != MPI_SUCCESS) { - LOG_ERROR("cannot create MPI FP16 max op"); - atl_mpi_print_error(ret); - return RET2ATL(ret); - } - -#endif // ATL_MPI_FP16 - - return RET2ATL(ret); -} - -static void atl_mpi_fp16_finalize() { - if (global_data.fp16.dtype != MPI_DATATYPE_NULL) { - 
MPI_Type_free(&global_data.fp16.dtype); - } - - if (global_data.fp16.sum_op != MPI_OP_NULL) { - MPI_Op_free(&global_data.fp16.sum_op); - } - - if (global_data.fp16.prod_op != MPI_OP_NULL) { - MPI_Op_free(&global_data.fp16.prod_op); - } - - if (global_data.fp16.min_op != MPI_OP_NULL) { - MPI_Op_free(&global_data.fp16.min_op); - } - - if (global_data.fp16.max_op != MPI_OP_NULL) { - MPI_Op_free(&global_data.fp16.max_op); - } -} - -static MPI_Datatype atl2mpi_dtype(atl_datatype_t dtype) { - switch (dtype) { - case ATL_DTYPE_INT8: return MPI_CHAR; - case ATL_DTYPE_UINT8: return MPI_UNSIGNED_CHAR; - case ATL_DTYPE_INT16: return MPI_INT16_T; - case ATL_DTYPE_UINT16: return MPI_UINT16_T; - case ATL_DTYPE_INT32: return MPI_INT; - case ATL_DTYPE_UINT32: return MPI_UINT32_T; - case ATL_DTYPE_INT64: return MPI_LONG_LONG; - case ATL_DTYPE_UINT64: return MPI_UNSIGNED_LONG_LONG; - case ATL_DTYPE_FLOAT16: return MPI_FLOAT16; - case ATL_DTYPE_FLOAT32: return MPI_FLOAT; - case ATL_DTYPE_FLOAT64: return MPI_DOUBLE; - case ATL_DTYPE_BFLOAT16: return MPI_BFLOAT16; - default: printf("unknown datatype: %d\n", dtype); exit(1); - } -} - -#ifdef ATL_MPI_BF16 -static MPI_Op atl2mpi_op_bf16(atl_reduction_t rtype) { - switch (rtype) { - case ATL_REDUCTION_SUM: return global_data.bf16.sum_op; - case ATL_REDUCTION_PROD: return global_data.bf16.prod_op; - case ATL_REDUCTION_MIN: return global_data.bf16.min_op; - case ATL_REDUCTION_MAX: return global_data.bf16.max_op; - default: printf("unknown reduction type: %d\n", rtype); exit(1); - } -} -#endif // ATL_MPI_BF16 - -#ifdef ATL_MPI_FP16 -static MPI_Op atl2mpi_op_fp16(atl_reduction_t rtype) { - switch (rtype) { - case ATL_REDUCTION_SUM: return global_data.fp16.sum_op; - case ATL_REDUCTION_PROD: return global_data.fp16.prod_op; - case ATL_REDUCTION_MIN: return global_data.fp16.min_op; - case ATL_REDUCTION_MAX: return global_data.fp16.max_op; - default: printf("unknown reduction type: %d\n", rtype); exit(1); - } -} -#endif // ATL_MPI_FP16 - -static 
MPI_Op atl2mpi_op(atl_reduction_t rtype, MPI_Datatype dtype) { -#ifdef ATL_MPI_BF16 - if (dtype == global_data.bf16.dtype) - return atl2mpi_op_bf16(rtype); -#endif // ATL_MPI_BF16 - -#ifdef ATL_MPI_FP16 - if (dtype == global_data.fp16.dtype) - return atl2mpi_op_fp16(rtype); -#endif // ATL_MPI_FP16 - - (void)dtype; - switch (rtype) { - case ATL_REDUCTION_SUM: return MPI_SUM; - case ATL_REDUCTION_PROD: return MPI_PROD; - case ATL_REDUCTION_MIN: return MPI_MIN; - case ATL_REDUCTION_MAX: return MPI_MAX; - default: printf("unknown reduction type: %d\n", rtype); exit(1); - } -} - -atl_mpi_lib_attr_t atl_mpi_get_lib_attr() { - atl_mpi_lib_attr_t lib_attr = { ATL_MPI_LIB_NONE, 0 }; - - char mpi_version[MPI_MAX_LIBRARY_VERSION_STRING] = { 0 }; - int mpi_version_len = -1, i; - atl_mpi_lib_info_t* final_info = NULL; - - /* can be called before MPI_Init */ - int ret = MPI_Get_library_version(mpi_version, &mpi_version_len); - - if ((ret != MPI_SUCCESS) || (mpi_version_len < 0) || - (mpi_version_len > MPI_MAX_LIBRARY_VERSION_STRING)) { - LOG_WARN("can not retrieve MPI version, mpi_version_len ", mpi_version_len, ", ret", ret); - return lib_attr; - } - - /* remove trailing spaces at the end for more compact log */ - while (strlen(mpi_version) && isspace(mpi_version[strlen(mpi_version) - 1])) - mpi_version[strlen(mpi_version) - 1] = '\0'; - - LOG_DEBUG("MPI version: ", mpi_version); - - /* for filtering */ - char* lib_type_env = getenv("CCL_ATL_MPI"); - - for (i = 0; i < MPI_LIB_INFO_MAX_COUNT; i++) { - atl_mpi_lib_info_t* info = &(mpi_lib_infos[i]); - - if (info->type == ATL_MPI_LIB_NONE) - continue; - - if (lib_type_env) { - if (strcmp(lib_type_env, info->name)) { - LOG_DEBUG("library ", info->name, " is filtered out by user input ", lib_type_env); - continue; - } - else { - LOG_DEBUG("use lib_type = ", lib_type_env, " because it is requested explicitly"); - } - } - - CCL_THROW_IF_NOT(info->version_prefix_1, "empty version_prefix_1"); - CCL_THROW_IF_NOT(info->min_version_value 
>= 0, "unexpected minimal version"); - - const char* version_substr = NULL; - if ((version_substr = strstr(mpi_version, info->version_prefix_1))) { - version_substr += strlen(info->version_prefix_1); - LOG_DEBUG("version_substr: ", version_substr); - - if (info->version_prefix_2) { - version_substr = strstr(version_substr, info->version_prefix_2); - if (!version_substr) { - LOG_DEBUG("can't find version_prefix_2 ", info->version_prefix_2); - continue; - } - version_substr += strlen(info->version_prefix_2); - LOG_DEBUG("version_substr: ", version_substr); - } - - int version_value = (version_substr) ? atoi(version_substr) : -1; - LOG_DEBUG("MPI numerical version: ", version_value); - - if (version_value < info->min_version_value) { - LOG_WARN("loaded MPI doesn't match with expected version, " - "consider to switch to ", - info->version_prefix_1, - " ", - (info->version_prefix_2 ? info->version_prefix_2 : ""), - info->min_version_value, - " (min) ", - (info->kind_value ? info->kind_value : ""), - "\n"); - continue; - } - - if (info->kind_prefix && info->kind_value) { - const char* kind_substr = mpi_version; - - if ((kind_substr = strstr(kind_substr, info->kind_prefix))) { - kind_substr += strlen(info->kind_prefix); - while ((isspace(*kind_substr)) && - (kind_substr < (mpi_version + mpi_version_len))) - kind_substr++; - - LOG_DEBUG("kind_substr: ", kind_substr); - - if (strncmp(kind_substr, info->kind_value, strlen(info->kind_value))) { - LOG_WARN("loaded MPI version (", - version_value, - ") ", - "is higher or equal to minimal expected version (", - info->min_version_value, - ") ", - "but kind (", - kind_substr, - ") doesn't match with expected kind (", - info->kind_value, - "), " - "consider to switch to ", - info->version_prefix_1, - " ", - (info->version_prefix_2 ? info->version_prefix_2 : ""), - info->min_version_value, - " (min version) ", - (info->kind_value ? 
info->kind_value : ""), - "\n"); - } - } - else { - LOG_DEBUG("MPI version is high enough, but kind_prefix (", - info->kind_prefix, - ") can not be found", - " treat this like expected kind (", - info->kind_value, - ") was found"); - } - } - - final_info = info; - LOG_DEBUG("set lib_type = ", - info->name, - " because " - "version (", - version_value, - ") is higher or equal to minimal expected version (", - info->min_version_value, - ")"); - - lib_attr.type = final_info->type; - lib_attr.hmem = (final_info->min_hmem_version_value >= version_value) ? 1 : 0; - - break; - } - } - - if (final_info) { - LOG_DEBUG("MPI library type: ", final_info->name); - } - else { - LOG_DEBUG("MPI library type: none"); - } - - return lib_attr; -} - -size_t atl_mpi_get_ep_count(const atl_attr_t& attr) { - size_t mpi_ep_count = attr.in.ep_count; - if (attr.in.enable_extra_ep) - mpi_ep_count += attr.in.enable_extra_ep; - return mpi_ep_count; -} - -size_t atl_mpi_get_ep_idx(size_t ep_idx) { - size_t mpi_ep_idx = ep_idx; - if (global_data.extra_ep) - mpi_ep_idx += global_data.extra_ep; - return mpi_ep_idx; -} - -/* set these knobs without detection of MPI library type */ -atl_status_t atl_mpi_set_base_env(const atl_attr_t& attr) { - setenv("PSM2_MULTI_EP", "1", 0); - setenv("FI_OFI_RXM_USE_HASH", "0", 0); - -#ifdef CCL_ENABLE_SYCL - setenv("FI_SHM_DISABLE_CMA", "1", 0); -#endif // CCL_ENABLE_SYCL - - setenv("MPIR_CVAR_DEFAULT_THREAD_LEVEL", "MPI_THREAD_MULTIPLE", 0); - - /* request IMPI level append library kind into MPI_Get_library_version output */ - setenv("I_MPI_INFO_LIBRARY_KIND", "1", 0); - - return ATL_STATUS_SUCCESS; -} - -atl_status_t atl_mpi_set_impi_env(const atl_attr_t& attr, const atl_mpi_lib_attr_t& lib_attr) { - char ep_count_str[MPI_MAX_INFO_VAL] = { 0 }; - snprintf(ep_count_str, MPI_MAX_INFO_VAL, "%zu", atl_mpi_get_ep_count(attr)); - - if (attr.in.ep_count) - setenv("I_MPI_OFI_ISEND_INJECT_THRESHOLD", "0", 0); - -#ifdef CCL_ENABLE_SYCL - setenv("I_MPI_SHM_CMA", "0", 0); - 
if (attr.in.enable_hmem && lib_attr.hmem) { - setenv("I_MPI_OFFLOAD", "2", 0); - setenv("I_MPI_OFFLOAD_TOPOLIB", "l0", 0); - setenv("I_MPI_OFFLOAD_QUEUE_CACHE", "1", 0); - setenv("I_MPI_OFFLOAD_LIST_CACHE", "1", 0); - setenv("I_MPI_OFFLOAD_MEMCPY_KIND", "blocked", 0); - if (attr.in.ep_count > 1) { - /* try to set global lock level before vci level - because setenv is invoked with overwrite=0 */ - setenv("I_MPI_THREAD_LOCK_LEVEL", "global", 0); - } - } -#endif // CCL_ENABLE_SYCL - - setenv("I_MPI_THREAD_SPLIT", "1", 0); - setenv("I_MPI_THREAD_RUNTIME", "generic", 0); - setenv("I_MPI_THREAD_MAX", ep_count_str, 0); - setenv("I_MPI_THREAD_ID_KEY", EP_IDX_KEY, 0); - setenv("I_MPI_THREAD_LOCK_LEVEL", "vci", 0); - - return ATL_STATUS_SUCCESS; -} - -atl_status_t atl_mpi_check_impi_env(const atl_attr_t& attr) { - char* ep_count_env = getenv("I_MPI_THREAD_MAX"); - if (!ep_count_env) - return ATL_STATUS_FAILURE; - if (atoi(ep_count_env) != (int)(atl_mpi_get_ep_count(attr))) - return ATL_STATUS_FAILURE; - - if (!getenv("I_MPI_ROOT")) { - atl_mpi_lib_type_t type = ATL_MPI_LIB_IMPI; - LOG_ERROR("CCL/MPI uses ", - mpi_lib_infos[type].version_prefix_1, - " but I_MPI_ROOT is not set. 
", - "Please source ", - mpi_lib_infos[type].kind_value, - " version of ", - mpi_lib_infos[type].version_prefix_1, - " (", - mpi_lib_infos[type].min_version_value, - " or higher version)."); - return ATL_STATUS_FAILURE; - } - - return ATL_STATUS_SUCCESS; -} - -atl_status_t atl_mpi_set_mpich_env(const atl_attr_t& attr) { - char ep_count_str[MPI_MAX_INFO_VAL] = { 0 }; - snprintf(ep_count_str, MPI_MAX_INFO_VAL, "%zu", atl_mpi_get_ep_count(attr)); - - setenv("MPIR_CVAR_CH4_MT_MODEL", "direct", 0); - setenv("MPIR_CVAR_CH4_NUM_VCIS", ep_count_str, 0); - setenv("MPIR_CVAR_CH4_OFI_MAX_VCIS", ep_count_str, 0); - setenv("MPIR_CVAR_CH4_ASYNC_PROGRESS_ID_KEY", EP_IDX_KEY, 0); - setenv("MPIR_CVAR_CH4_OFI_ENABLE_SCALABLE_ENDPOINTS", "1", 0); - - if (attr.in.mnic_type != ATL_MNIC_NONE) { - setenv("MPIR_CVAR_CH4_OFI_ENABLE_NIC_SELECTION", "1", 0); - auto& env = ccl::global_data::env(); - if (env.log_level >= ccl_log_level::info) { - setenv("MPIR_CVAR_CH4_OFI_DUMP_NIC_SETTINGS", "1", 0); - } - } - - setenv("FI_PSM2_DELAY", "0", 0); - setenv("FI_PSM2_TIMEOUT", "0", 0); - setenv("FI_PSM2_NAME_SERVER", "0", 0); - setenv("HFI_NO_CPUAFFINITY", "1", 0); - - return ATL_STATUS_SUCCESS; -} - -atl_status_t atl_mpi_check_mpich_env(const atl_attr_t& attr) { - char* ep_count_env = getenv("MPIR_CVAR_CH4_OFI_MAX_VCIS"); - if (!ep_count_env) - return ATL_STATUS_FAILURE; - if (atoi(ep_count_env) != (int)(atl_mpi_get_ep_count(attr))) - return ATL_STATUS_FAILURE; - return ATL_STATUS_SUCCESS; -} - -atl_status_t atl_mpi_set_env(const atl_attr_t& attr) { - if (global_data.mpi_lib_attr.type != ATL_MPI_LIB_NONE) { - /* library type was already detected and env was set, make sanity check */ - if (global_data.mpi_lib_attr.type == ATL_MPI_LIB_IMPI) { - return atl_mpi_check_impi_env(attr); - } - else if (global_data.mpi_lib_attr.type == ATL_MPI_LIB_MPICH) { - return atl_mpi_check_mpich_env(attr); - } - return ATL_STATUS_SUCCESS; - } - - atl_mpi_set_base_env(attr); - - atl_mpi_lib_attr_t mpi_lib_attr = 
atl_mpi_get_lib_attr(); - - if (mpi_lib_attr.type == ATL_MPI_LIB_NONE) { - return ATL_STATUS_SUCCESS; - } - - if (mpi_lib_attr.type == ATL_MPI_LIB_IMPI) { - atl_mpi_set_impi_env(attr, mpi_lib_attr); - atl_mpi_check_impi_env(attr); - } - else if (mpi_lib_attr.type == ATL_MPI_LIB_MPICH) { - atl_mpi_set_mpich_env(attr); - atl_mpi_check_mpich_env(attr); - } - - int is_mpi_inited = 0; - MPI_Initialized(&is_mpi_inited); - if (is_mpi_inited) { - LOG_WARN("MPI was initialized externally, CCL-MPI specific environment is ignored"); - } - else { - LOG_DEBUG("set CCL-MPI specific environment"); - } - - global_data.mpi_lib_attr = mpi_lib_attr; - - return ATL_STATUS_SUCCESS; -} - -atl_mpi_comm_info_t atl_mpi_get_comm_info(MPI_Comm comm, const char* key) { - MPI_Info info; - atl_mpi_comm_info_t res; - res.comm = comm; - snprintf(res.key, MPI_MAX_INFO_KEY, "%s", key); - - MPI_Comm_get_info(res.comm, &info); - MPI_Info_get(info, key, MPI_MAX_INFO_VAL, res.value, &res.found); - MPI_Info_free(&info); - - return res; -} - -size_t atl_mpi_get_nic_count(const char* nic_count_key) { - size_t count = 1; - atl_mpi_comm_info_t info = atl_mpi_get_comm_info(MPI_COMM_WORLD, nic_count_key); - CCL_THROW_IF_NOT(info.found, "MPI comm key ", nic_count_key, " was not set"); - - count = atoi(info.value); - if (count <= 0) { - count = 1; - } - - return count; -} - -void atl_mpi_check_comm_info(MPI_Comm comm, const char* key, const char* expected_value) { - atl_mpi_comm_info_t info = atl_mpi_get_comm_info(comm, key); - - CCL_THROW_IF_NOT(info.found, "MPI comm key ", key, " was not set"); - CCL_THROW_IF_NOT(!strcmp(info.value, expected_value), - "MPI comm key ", - key, - ": expected: ", - expected_value, - ", read: ", - info.value); -} - -void atl_mpi_check_comm_ep_idx(MPI_Comm comm, size_t expected_idx) { - if (global_data.mpi_lib_attr.type == ATL_MPI_LIB_NONE) - return; - - char expected_idx_str[MPI_MAX_INFO_VAL] = { 0 }; - snprintf(expected_idx_str, MPI_MAX_INFO_VAL, "%zu", expected_idx); - 
atl_mpi_check_comm_info(comm, EP_IDX_KEY, expected_idx_str); -} - -void atl_mpi_check_comm_nic_idx(MPI_Comm comm, size_t expected_idx, const char* nic_idx_key) { - char expected_idx_str[MPI_MAX_INFO_VAL] = { 0 }; - snprintf(expected_idx_str, MPI_MAX_INFO_VAL, "%zu", expected_idx); - atl_mpi_check_comm_info(comm, nic_idx_key, expected_idx_str); -} - -#ifdef ENABLE_DEBUG -inline void atl_mpi_check_ep(atl_ep_t* ep) { - atl_mpi_ep_t* mpi_ep = container_of(ep, atl_mpi_ep_t, ep); - atl_mpi_check_comm_ep_idx(mpi_ep->mpi_comm, atl_mpi_get_ep_idx(ep->idx)); -} -#else -#define atl_mpi_check_ep(ep) -#endif - -static atl_status_t atl_mpi_finalize(atl_ctx_t* ctx) { - int ret = MPI_SUCCESS; - atl_mpi_ctx_t* mpi_ctx = container_of(ctx, atl_mpi_ctx_t, ctx); - atl_ep_t** eps = ctx->eps; - - global_data.ctx_count--; - if (ctx->coord.global_idx == 0) { - LOG_INFO("finalize atl-mpi ctx, remaining ctx_count ", global_data.ctx_count); - } - - int is_mpi_finalized = 0; - MPI_Finalized(&is_mpi_finalized); - - if (!is_mpi_finalized) { - for (size_t i = 0; i < ctx->ep_count; i++) { - atl_mpi_ep_t* mpi_ep = container_of(eps[i], atl_mpi_ep_t, ep); - - if (mpi_ep) { - if (mpi_ctx->progress_mode == ATL_PROGRESS_POLL) { - MPI_Cancel(&(mpi_ep->dummy_req.native_req)); - MPI_Comm_free(&mpi_ep->dummy_comm); - } - MPI_Comm_free(&mpi_ep->mpi_comm); - free(mpi_ep); - } - } - - if (global_data.ctx_count == 0) { - atl_mpi_bf16_finalize(); - atl_mpi_fp16_finalize(); - if (!global_data.is_external_init) { - ret = MPI_Finalize(); - } - else { - LOG_DEBUG("MPI_Init has been called externally, skip MPI_Finalize"); - } - - if (ctx->coord.global_idx == 0) { - LOG_INFO("finalized last atl-mpi ctx"); - } - } - } - else { - for (size_t i = 0; i < ctx->ep_count; i++) { - atl_mpi_ep_t* mpi_ep = container_of(eps[i], atl_mpi_ep_t, ep); - free(mpi_ep); - } - if ((global_data.ctx_count == 0) && (ctx->coord.global_idx == 0)) { - LOG_WARN("MPI_Finalize has been called before CCL finalization"); - } - } - - free(eps); - 
free(mpi_ctx); - - return RET2ATL(ret); -} - -static atl_status_t atl_mpi_mr_reg(atl_ctx_t* ctx, const void* buf, size_t len, atl_mr_t** mr) { - return ATL_STATUS_UNSUPPORTED; -} - -static atl_status_t atl_mpi_mr_dereg(atl_ctx_t* ctx, atl_mr_t* mr) { - return ATL_STATUS_UNSUPPORTED; -} - -static atl_status_t atl_mpi_ep_send(atl_ep_t* ep, - const void* buf, - size_t len, - int dst_proc_idx, - uint64_t tag, - atl_req_t* req) { - atl_mpi_ep_t* mpi_ep = container_of(ep, atl_mpi_ep_t, ep); - atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); - mpi_req->comp_state = ATL_MPI_COMP_POSTED; - - int ret = MPI_Isend( - buf, len, MPI_CHAR, dst_proc_idx, (int)tag, mpi_ep->mpi_comm, &mpi_req->native_req); - - atl_mpi_check_ep(ep); - - return RET2ATL(ret); -} - -static atl_status_t atl_mpi_ep_recv(atl_ep_t* ep, - void* buf, - size_t len, - int src_proc_idx, - uint64_t tag, - atl_req_t* req) { - atl_mpi_ep_t* mpi_ep = container_of(ep, atl_mpi_ep_t, ep); - atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); - mpi_req->comp_state = ATL_MPI_COMP_POSTED; - - int ret = MPI_Irecv( - buf, len, MPI_CHAR, src_proc_idx, (int)tag, mpi_ep->mpi_comm, &mpi_req->native_req); - - atl_mpi_check_ep(ep); - - return RET2ATL(ret); -} - -static atl_status_t atl_mpi_ep_probe(atl_ep_t* ep, - int src_proc_idx, - uint64_t tag, - int* found, - size_t* recv_len) { - atl_mpi_ep_t* mpi_ep = container_of(ep, atl_mpi_ep_t, ep); - - int flag = 0, len = 0, ret; - MPI_Status status; - - ret = MPI_Iprobe(src_proc_idx, tag, mpi_ep->mpi_comm, &flag, &status); - if (flag) { - MPI_Get_count(&status, MPI_BYTE, &len); - } - - if (found) - *found = flag; - if (recv_len) - *recv_len = len; - - atl_mpi_check_ep(ep); - - return RET2ATL(ret); -} - -static atl_status_t atl_mpi_ep_allgatherv(atl_ep_t* ep, - const void* send_buf, - size_t send_len, - void* recv_buf, - const int* recv_lens, - const int* offsets, - atl_req_t* req) { - int ret = MPI_SUCCESS; - - atl_mpi_ep_t* mpi_ep = container_of(ep, atl_mpi_ep_t, ep); 
- atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); - atl_mpi_ctx_t* mpi_ctx = container_of(ep->ctx, atl_mpi_ctx_t, ctx); - - if (mpi_ctx->sync_coll) { - ret = MPI_Allgatherv((send_buf && (send_buf == recv_buf)) ? MPI_IN_PLACE : send_buf, - send_len, - MPI_CHAR, - recv_buf, - recv_lens, - offsets, - MPI_CHAR, - mpi_ep->mpi_comm); - mpi_req->comp_state = ATL_MPI_COMP_COMPLETED; - mpi_req->native_req = MPI_REQUEST_NULL; - } - else { - ret = MPI_Iallgatherv((send_buf && (send_buf == recv_buf)) ? MPI_IN_PLACE : send_buf, - send_len, - MPI_CHAR, - recv_buf, - recv_lens, - offsets, - MPI_CHAR, - mpi_ep->mpi_comm, - &mpi_req->native_req); - mpi_req->comp_state = ATL_MPI_COMP_POSTED; - } - - atl_mpi_check_ep(ep); - - return RET2ATL(ret); -} - -static atl_status_t atl_mpi_ep_allreduce(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t count, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) { - int ret = MPI_SUCCESS; - - atl_mpi_ep_t* mpi_ep = container_of(ep, atl_mpi_ep_t, ep); - atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); - atl_mpi_ctx_t* mpi_ctx = container_of(ep->ctx, atl_mpi_ctx_t, ctx); - - MPI_Datatype mpi_dtype = atl2mpi_dtype(dtype); - MPI_Op mpi_op = atl2mpi_op(op, mpi_dtype); - - if (mpi_ctx->sync_coll) { - ret = MPI_Allreduce((send_buf && (send_buf == recv_buf)) ? MPI_IN_PLACE : send_buf, - recv_buf, - count, - mpi_dtype, - mpi_op, - mpi_ep->mpi_comm); - mpi_req->comp_state = ATL_MPI_COMP_COMPLETED; - mpi_req->native_req = MPI_REQUEST_NULL; - } - else { - //printf("atl_mpi: send_buf %p, recv_buf %p\n", send_buf, recv_buf); - ret = MPI_Iallreduce((send_buf && (send_buf == recv_buf)) ? 
MPI_IN_PLACE : send_buf, - recv_buf, - count, - mpi_dtype, - mpi_op, - mpi_ep->mpi_comm, - &mpi_req->native_req); - mpi_req->comp_state = ATL_MPI_COMP_POSTED; - } - - atl_mpi_check_ep(ep); - - return RET2ATL(ret); -} - -static atl_status_t atl_mpi_ep_alltoall(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t len, - atl_req_t* req) { - int ret = MPI_SUCCESS; - - atl_mpi_ep_t* mpi_ep = container_of(ep, atl_mpi_ep_t, ep); - atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); - atl_mpi_ctx_t* mpi_ctx = container_of(ep->ctx, atl_mpi_ctx_t, ctx); - - if (mpi_ctx->sync_coll) { - ret = MPI_Alltoall((send_buf && (send_buf == recv_buf)) ? MPI_IN_PLACE : send_buf, - len, - MPI_CHAR, - recv_buf, - len, - MPI_CHAR, - mpi_ep->mpi_comm); - mpi_req->comp_state = ATL_MPI_COMP_COMPLETED; - mpi_req->native_req = MPI_REQUEST_NULL; - } - else { - ret = MPI_Ialltoall((send_buf && (send_buf == recv_buf)) ? MPI_IN_PLACE : send_buf, - len, - MPI_CHAR, - recv_buf, - len, - MPI_CHAR, - mpi_ep->mpi_comm, - &mpi_req->native_req); - mpi_req->comp_state = ATL_MPI_COMP_POSTED; - } - - atl_mpi_check_ep(ep); - - return RET2ATL(ret); -} - -static atl_status_t atl_mpi_ep_alltoallv(atl_ep_t* ep, - const void* send_buf, - const int* send_lens, - const int* send_offsets, - void* recv_buf, - const int* recv_lens, - const int* recv_offsets, - atl_req_t* req) { - int ret = MPI_SUCCESS; - - atl_mpi_ep_t* mpi_ep = container_of(ep, atl_mpi_ep_t, ep); - atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); - atl_mpi_ctx_t* mpi_ctx = container_of(ep->ctx, atl_mpi_ctx_t, ctx); - - if (mpi_ctx->sync_coll) { - ret = MPI_Alltoallv((send_buf && (send_buf == recv_buf)) ? MPI_IN_PLACE : send_buf, - send_lens, - send_offsets, - MPI_CHAR, - recv_buf, - recv_lens, - recv_offsets, - MPI_CHAR, - mpi_ep->mpi_comm); - mpi_req->comp_state = ATL_MPI_COMP_COMPLETED; - mpi_req->native_req = MPI_REQUEST_NULL; - } - else { - ret = MPI_Ialltoallv((send_buf && (send_buf == recv_buf)) ? 
MPI_IN_PLACE : send_buf, - send_lens, - send_offsets, - MPI_CHAR, - recv_buf, - recv_lens, - recv_offsets, - MPI_CHAR, - mpi_ep->mpi_comm, - &mpi_req->native_req); - mpi_req->comp_state = ATL_MPI_COMP_POSTED; - } - - atl_mpi_check_ep(ep); - - return RET2ATL(ret); -} - -static atl_status_t atl_mpi_ep_barrier(atl_ep_t* ep, atl_req_t* req) { - int ret = MPI_SUCCESS; - - atl_mpi_ep_t* mpi_ep = container_of(ep, atl_mpi_ep_t, ep); - atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); - atl_mpi_ctx_t* mpi_ctx = container_of(ep->ctx, atl_mpi_ctx_t, ctx); - - if (mpi_ctx->sync_coll) { - ret = MPI_Barrier(mpi_ep->mpi_comm); - mpi_req->comp_state = ATL_MPI_COMP_COMPLETED; - mpi_req->native_req = MPI_REQUEST_NULL; - } - else { - ret = MPI_Ibarrier(mpi_ep->mpi_comm, &mpi_req->native_req); - mpi_req->comp_state = ATL_MPI_COMP_POSTED; - } - - atl_mpi_check_ep(ep); - - return RET2ATL(ret); -} - -static atl_status_t atl_mpi_ep_bcast(atl_ep_t* ep, - void* buf, - size_t len, - int root, - atl_req_t* req) { - int ret = MPI_SUCCESS; - - atl_mpi_ep_t* mpi_ep = container_of(ep, atl_mpi_ep_t, ep); - atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); - atl_mpi_ctx_t* mpi_ctx = container_of(ep->ctx, atl_mpi_ctx_t, ctx); - - if (mpi_ctx->sync_coll) { - ret = MPI_Bcast(buf, len, MPI_CHAR, root, mpi_ep->mpi_comm); - mpi_req->comp_state = ATL_MPI_COMP_COMPLETED; - mpi_req->native_req = MPI_REQUEST_NULL; - } - else { - ret = MPI_Ibcast(buf, len, MPI_CHAR, root, mpi_ep->mpi_comm, &mpi_req->native_req); - mpi_req->comp_state = ATL_MPI_COMP_POSTED; - } - - atl_mpi_check_ep(ep); - - return RET2ATL(ret); -} - -static atl_status_t atl_mpi_ep_reduce(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t count, - int root, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) { - int ret = MPI_SUCCESS; - - atl_mpi_ep_t* mpi_ep = container_of(ep, atl_mpi_ep_t, ep); - atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); - atl_mpi_ctx_t* mpi_ctx = container_of(ep->ctx, 
atl_mpi_ctx_t, ctx); - - int my_proc_idx = ep->ctx->coord.global_idx; - MPI_Datatype mpi_dtype = atl2mpi_dtype(dtype); - MPI_Op mpi_op = atl2mpi_op(op, mpi_dtype); - - if (mpi_ctx->sync_coll) { - ret = MPI_Reduce( - (send_buf && (send_buf == recv_buf) && (root == my_proc_idx)) ? MPI_IN_PLACE : send_buf, - recv_buf, - count, - mpi_dtype, - mpi_op, - root, - mpi_ep->mpi_comm); - mpi_req->comp_state = ATL_MPI_COMP_COMPLETED; - mpi_req->native_req = MPI_REQUEST_NULL; - } - else { - ret = MPI_Ireduce( - (send_buf && (send_buf == recv_buf) && (root == my_proc_idx)) ? MPI_IN_PLACE : send_buf, - recv_buf, - count, - mpi_dtype, - mpi_op, - root, - mpi_ep->mpi_comm, - &mpi_req->native_req); - mpi_req->comp_state = ATL_MPI_COMP_POSTED; - } - - atl_mpi_check_ep(ep); - - return RET2ATL(ret); -} - -static atl_status_t atl_mpi_ep_reduce_scatter(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t recv_count, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) { - int ret = MPI_SUCCESS; - - atl_mpi_ep_t* mpi_ep = container_of(ep, atl_mpi_ep_t, ep); - atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); - atl_mpi_ctx_t* mpi_ctx = container_of(ep->ctx, atl_mpi_ctx_t, ctx); - - MPI_Datatype mpi_dtype = atl2mpi_dtype(dtype); - MPI_Op mpi_op = atl2mpi_op(op, mpi_dtype); - - if (mpi_ctx->sync_coll) { - ret = - MPI_Reduce_scatter_block((send_buf && (send_buf == recv_buf)) ? MPI_IN_PLACE : send_buf, - recv_buf, - recv_count, - mpi_dtype, - mpi_op, - mpi_ep->mpi_comm); - mpi_req->comp_state = ATL_MPI_COMP_COMPLETED; - mpi_req->native_req = MPI_REQUEST_NULL; - } - else { - ret = MPI_Ireduce_scatter_block( - (send_buf && (send_buf == recv_buf)) ? 
MPI_IN_PLACE : send_buf, - recv_buf, - recv_count, - mpi_dtype, - mpi_op, - mpi_ep->mpi_comm, - &mpi_req->native_req); - mpi_req->comp_state = ATL_MPI_COMP_POSTED; - } - - atl_mpi_check_ep(ep); - - return RET2ATL(ret); -} - -static atl_status_t atl_mpi_ep_read(atl_ep_t* ep, - void* buf, - size_t len, - atl_mr_t* mr, - uint64_t addr, - uintptr_t r_key, - int dst_proc_idx, - atl_req_t* req) { - return ATL_STATUS_UNSUPPORTED; -} - -static atl_status_t atl_mpi_ep_write(atl_ep_t* ep, - const void* buf, - size_t len, - atl_mr_t* mr, - uint64_t addr, - uintptr_t r_key, - int dst_proc_idx, - atl_req_t* req) { - return ATL_STATUS_UNSUPPORTED; -} - -static atl_status_t atl_mpi_ep_wait(atl_ep_t* ep, atl_req_t* req) { - int ret; - MPI_Status status; - atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); - ret = MPI_Wait(&mpi_req->native_req, &status); - mpi_req->comp_state = ATL_MPI_COMP_COMPLETED; - return RET2ATL(ret); -} - -static atl_status_t atl_mpi_ep_wait_all(atl_ep_t* ep, atl_req_t* reqs, size_t count) { - return ATL_STATUS_UNSUPPORTED; -} - -static inline atl_status_t atl_mpi_ep_progress(atl_ep_t* ep, atl_mpi_req_t* req) { - int flag = 0; - int ret = MPI_Test(&req->native_req, &flag, MPI_STATUS_IGNORE); - - if (flag) { - req->comp_state = ATL_MPI_COMP_COMPLETED; - } - - return RET2ATL(ret); -} - -static inline atl_status_t atl_mpi_ep_poll(atl_ep_t* ep) { - atl_mpi_ctx_t* mpi_ctx = container_of(ep->ctx, atl_mpi_ctx_t, ctx); - if (mpi_ctx->progress_mode == ATL_PROGRESS_POLL) { - atl_mpi_ep_t* mpi_ep = container_of(ep, atl_mpi_ep_t, ep); - atl_mpi_ep_progress(ep, &(mpi_ep->dummy_req)); - } - - return ATL_STATUS_SUCCESS; -} - -static atl_status_t atl_mpi_ep_check(atl_ep_t* ep, int* is_completed, atl_req_t* req) { - CCL_THROW_IF_NOT(is_completed); - - atl_status_t status = ATL_STATUS_SUCCESS; - - atl_mpi_req_t* mpi_req = ((atl_mpi_req_t*)req->internal); - - *is_completed = (mpi_req->comp_state == ATL_MPI_COMP_COMPLETED); - if (*is_completed) { - return 
ATL_STATUS_SUCCESS; - } - - status = atl_mpi_ep_progress(ep, mpi_req); - *is_completed = (mpi_req->comp_state == ATL_MPI_COMP_COMPLETED); - - return status; -} - -static atl_status_t atl_mpi_ep_init(atl_mpi_ctx_t* mpi_ctx, size_t idx, atl_ep_t** ep) { - int ret; - - ssize_t mpi_ep_idx = atl_mpi_get_ep_idx(idx); - char mpi_ep_idx_str[MPI_MAX_INFO_VAL] = { 0 }; - - size_t nic_idx = 0; - char nic_idx_str[MPI_MAX_INFO_VAL] = { 0 }; - const char* nic_idx_key = - (global_data.mnic_type == ATL_MNIC_GLOBAL) ? GLOBAL_NIC_IDX_KEY : LOCAL_NIC_IDX_KEY; - - atl_mpi_ep_t* mpi_ep = (atl_mpi_ep_t*)calloc(1, sizeof(atl_mpi_ep_t)); - if (!mpi_ep) - return ATL_STATUS_FAILURE; - - ret = MPI_Comm_dup(MPI_COMM_WORLD, &mpi_ep->mpi_comm); - if (ret) - goto err_ep_dup; - - MPI_Info info; - MPI_Info_create(&info); - - /* set EP index */ - snprintf(mpi_ep_idx_str, MPI_MAX_INFO_VAL, "%zu", mpi_ep_idx); - MPI_Info_set(info, EP_IDX_KEY, mpi_ep_idx_str); - - if (global_data.mnic_type != ATL_MNIC_NONE) { - /* set NIC index */ - nic_idx = (idx % global_data.mnic_count); - snprintf(nic_idx_str, MPI_MAX_INFO_VAL, "%zu", nic_idx); - MPI_Info_set(info, nic_idx_key, nic_idx_str); - } - - MPI_Comm_set_info(mpi_ep->mpi_comm, info); - - if (mpi_ctx->progress_mode == ATL_PROGRESS_POLL) { - ret = MPI_Comm_dup(MPI_COMM_WORLD, &mpi_ep->dummy_comm); - if (ret) - goto err_ep_dup; - MPI_Comm_set_info(mpi_ep->dummy_comm, info); - MPI_Irecv(NULL, 0, MPI_CHAR, 0, 0, mpi_ep->dummy_comm, &(mpi_ep->dummy_req.native_req)); - - atl_mpi_check_comm_ep_idx(mpi_ep->dummy_comm, mpi_ep_idx); - if (global_data.mnic_type != ATL_MNIC_NONE) { - atl_mpi_check_comm_nic_idx(mpi_ep->dummy_comm, nic_idx, nic_idx_key); - } - } - - MPI_Info_free(&info); - - atl_mpi_check_comm_ep_idx(mpi_ep->mpi_comm, mpi_ep_idx); - if (global_data.mnic_type != ATL_MNIC_NONE) { - atl_mpi_check_comm_nic_idx(mpi_ep->mpi_comm, nic_idx, nic_idx_key); - } - - LOG_DEBUG("atl-mpi-ep: ", idx, ", ep_idx ", mpi_ep_idx, ", nic_idx ", nic_idx); - - *ep = 
&mpi_ep->ep; - (*ep)->idx = idx; - (*ep)->ctx = &mpi_ctx->ctx; - - return ATL_STATUS_SUCCESS; - -err_ep_dup: - free(mpi_ep); - return RET2ATL(ret); -} - -static atl_status_t atl_mpi_init(int* argc, - char*** argv, - atl_attr_t* attr, - atl_ctx_t** out_ctx, - const char* main_addr, - ipmi* pmi) { - CCL_THROW_IF_NOT((sizeof(atl_mpi_req_t) <= sizeof(atl_req_t) - offsetof(atl_req_t, internal)), - "unexpected offset: atl_mpi_request size ", - sizeof(atl_mpi_req_t), - ", atl_request size ", - sizeof(atl_req_t), - ", expected offset ", - offsetof(atl_req_t, internal)); - - int ret = MPI_SUCCESS; - size_t i; - int is_tag_ub_set = 0; - void* tag_ub_ptr = NULL; - int required_thread_level = MPI_THREAD_MULTIPLE, provided_thread_level; - - char my_hostname[ATL_MAX_HOSTNAME_LEN] = { 0 }; - - atl_mpi_ctx_t* mpi_ctx = (atl_mpi_ctx_t*)calloc(1, sizeof(atl_mpi_ctx_t)); - if (!mpi_ctx) - return ATL_STATUS_FAILURE; - - atl_ctx_t* ctx = &(mpi_ctx->ctx); - - if (global_data.ctx_count == 0) { - if (atl_mpi_set_env(*attr)) { - goto err_init; - } - - MPI_Initialized(&global_data.is_external_init); - - if (!global_data.is_external_init) { - ret = MPI_Init_thread(argc, argv, required_thread_level, &provided_thread_level); - if (provided_thread_level < required_thread_level) { - LOG_ERROR("unexpected MPI thread level: required ", - required_thread_level, - ", provided ", - provided_thread_level); - goto err_init; - } - } - else { - LOG_DEBUG("MPI was initialized externaly"); - MPI_Query_thread(&provided_thread_level); - if (provided_thread_level < required_thread_level) { - LOG_WARN("MPI was initialized externaly but with unexpected thread level: " - "required ", - required_thread_level, - ", provided ", - provided_thread_level); - } - } - - if (ret) - goto err_init; - - if (global_data.mpi_lib_attr.type == ATL_MPI_LIB_NONE) - global_data.mpi_lib_attr = atl_mpi_get_lib_attr(); - - global_data.extra_ep = attr->in.enable_extra_ep; - - global_data.mnic_type = attr->in.mnic_type; - if 
(global_data.mpi_lib_attr.type != ATL_MPI_LIB_MPICH) { - /* only MPICH supports multi-NIC */ - global_data.mnic_type = ATL_MNIC_NONE; - } - - if (global_data.mnic_type == ATL_MNIC_LOCAL) { - global_data.mnic_count = atl_mpi_get_nic_count(LOCAL_NIC_COUNT_KEY); - } - else if (global_data.mnic_type == ATL_MNIC_GLOBAL) { - global_data.mnic_count = atl_mpi_get_nic_count(GLOBAL_NIC_IDX_KEY); - } - else if (global_data.mnic_type == ATL_MNIC_NONE) { - global_data.mnic_count = 1; - } - global_data.mnic_count = std::min(global_data.mnic_count, attr->in.mnic_count); - global_data.mnic_count = std::min(global_data.mnic_count, attr->in.ep_count); - global_data.mnic_count = std::max(global_data.mnic_count, (size_t)(1)); - - if (atl_mpi_bf16_init() == ATL_STATUS_FAILURE) { - atl_mpi_bf16_finalize(); - goto err_init; - } - - if (atl_mpi_fp16_init() == ATL_STATUS_FAILURE) { - atl_mpi_fp16_finalize(); - goto err_init; - } - } - global_data.ctx_count++; - - atl_proc_coord_t* coord; - coord = &(ctx->coord); - - MPI_Comm_rank(MPI_COMM_WORLD, (int*)&(coord->global_idx)); - MPI_Comm_size(MPI_COMM_WORLD, (int*)&(coord->global_count)); - - MPI_Comm local_comm; - MPI_Comm_split_type( - MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, coord->global_count, MPI_INFO_NULL, &local_comm); - MPI_Comm_rank(local_comm, (int*)&(coord->local_idx)); - MPI_Comm_size(local_comm, (int*)&(coord->local_count)); - MPI_Comm_free(&local_comm); - - gethostname(my_hostname, ATL_MAX_HOSTNAME_LEN - 1); - coord->hostname_hash = std::hash{}(my_hostname); - - ctx->ep_count = attr->in.ep_count; - ctx->eps = (atl_ep_t**)calloc(1, sizeof(void*) * attr->in.ep_count); - if (!ctx->eps) - goto err_after_init; - - char* progress_mode_env; - progress_mode_env = getenv(ATL_PROGRESS_MODE_ENV); - if (progress_mode_env) { - mpi_ctx->progress_mode = (atl_progress_mode_t)atoi(progress_mode_env); - } - else { - mpi_ctx->progress_mode = ATL_PROGRESS_CHECK; - } - mpi_ctx->sync_coll = attr->in.enable_sync_coll; - - if (coord->global_idx == 0) { - 
if (global_data.ctx_count == 1) { - LOG_INFO("atl-mpi-global:") - LOG_INFO(" is_external_init: ", global_data.is_external_init); - LOG_INFO(" mpi_lib_attr.type: ", mpi_lib_infos[global_data.mpi_lib_attr.type].name); - LOG_INFO(" mpi_lib_attr.hmem: ", global_data.mpi_lib_attr.hmem); - LOG_INFO(" extra_ep: ", global_data.extra_ep); - LOG_INFO(" mnic_type: ", global_data.mnic_type); - if (global_data.mnic_type != ATL_MNIC_NONE) - LOG_INFO(" mnic_count: ", global_data.mnic_count); - } - LOG_INFO("atl-mpi-ctx: ", (global_data.ctx_count - 1)); - LOG_INFO(" progress_mode: ", mpi_ctx->progress_mode); - LOG_INFO(" sync_coll: ", mpi_ctx->sync_coll); - } - - for (i = 0; i < attr->in.ep_count; i++) { - ret = atl_mpi_ep_init(mpi_ctx, i, &(ctx->eps[i])); - if (ret) - goto err_ep_dup; - } - - *out_ctx = &mpi_ctx->ctx; - - MPI_Comm_get_attr(MPI_COMM_WORLD, MPI_TAG_UB, &tag_ub_ptr, &is_tag_ub_set); - - /* report actual attributes back to upper level */ - attr->out.enable_shm = 0; - attr->out.enable_rma = 0; - attr->out.enable_hmem = attr->in.enable_hmem & global_data.mpi_lib_attr.hmem; - attr->out.mnic_type = global_data.mnic_type; - attr->out.mnic_count = global_data.mnic_count; - attr->out.tag_bits = 32; - attr->out.max_tag = (is_tag_ub_set) ? 
*((int*)tag_ub_ptr) : 0; - attr->out.max_order_waw_size = 0; - - return ATL_STATUS_SUCCESS; - -err_ep_dup: - for (i = 0; i < attr->in.ep_count; i++) { - atl_mpi_ep_t* mpi_ep = container_of(ctx->eps[i], atl_mpi_ep_t, ep); - - if (ctx->eps[i] && mpi_ep) { - if (mpi_ctx->progress_mode == ATL_PROGRESS_POLL) { - MPI_Cancel(&(mpi_ep->dummy_req.native_req)); - MPI_Comm_free(&mpi_ep->dummy_comm); - } - MPI_Comm_free(&mpi_ep->mpi_comm); - } - } - free(ctx->eps); - -err_after_init: - global_data.ctx_count--; - if (global_data.ctx_count == 0) { - atl_mpi_bf16_finalize(); - atl_mpi_fp16_finalize(); - if (!global_data.is_external_init) { - MPI_Finalize(); - } - } - -err_init: - free(mpi_ctx); - return ATL_STATUS_FAILURE; -} - -atl_status_t atl_mpi_main_addr_reserve(char* main_addr) { - return ATL_STATUS_UNSUPPORTED; -} - -#endif // CCL_ENABLE_MPI diff --git a/src/atl/ofi/atl_ofi.cpp b/src/atl/ofi/atl_ofi.cpp index 21bc2cb6f..bfe0a6fce 100644 --- a/src/atl/ofi/atl_ofi.cpp +++ b/src/atl/ofi/atl_ofi.cpp @@ -82,8 +82,11 @@ void atl_ofi::mr_cache::get(fid_domain* domain, void* buf, size_t bytes, fid_mr* } } - struct fi_mr_attr mr_attr = {}; - struct iovec iov = {}; + struct fi_mr_attr mr_attr; + struct iovec iov; + + memset(&mr_attr, 0, sizeof(mr_attr)); + memset(&iov, 0, sizeof(iov)); iov.iov_base = buf; iov.iov_len = bytes; @@ -114,7 +117,7 @@ void atl_ofi::mr_cache::get(fid_domain* domain, void* buf, size_t bytes, fid_mr* ZE_CALL(zeDeviceGetProperties, (alloc_dev, &alloc_dev_props)); int dev_idx = -1; - for (int idx = 0; idx < ze_data.device_count; idx++) { + for (int idx = 0; idx < static_cast(ze_data.device_count); idx++) { ze_device_properties_t dev_props = ccl::ze::default_device_props; ZE_CALL(zeDeviceGetProperties, (ze_data.devices[idx], &dev_props)); @@ -163,11 +166,11 @@ atl_status_t atl_ofi::atl_set_env(const atl_attr_t& attr) { return atl_ofi_set_env(attr); } -atl_status_t atl_ofi::atl_init(int* argc, - char*** argv, - atl_attr_t* attr, - const char* main_addr, - 
std::unique_ptr& pmi) { +atl_status_t atl_ofi::init(int* argc, + char*** argv, + atl_attr_t* attr, + const char* main_addr, + std::shared_ptr pmi) { inited = true; struct fi_info *prov_list = nullptr, *base_hints = nullptr, *prov_hints = nullptr; int fi_version; @@ -191,10 +194,7 @@ atl_status_t atl_ofi::atl_init(int* argc, if (global_data.ctx_count == 0) { ret = atl_ofi_set_env(*attr); - if (ret != ATL_STATUS_SUCCESS) { - LOG_ERROR("atl_ofi_set_env error"); - return ATL_STATUS_FAILURE; - } + ATL_CHECK_STATUS(ret, "atl_ofi_set_env error"); fi_version_env = getenv(ATL_OFI_MAJOR_VERSION); if (fi_version_env) { @@ -222,9 +222,7 @@ atl_status_t atl_ofi::atl_init(int* argc, ctx = &(ofi_ctx->ctx); ctx->ep_count = attr->in.ep_count; - ctx->eps = (atl_ep**)calloc(1, sizeof(void*) * attr->in.ep_count); - if (!ctx->eps) - goto err; + eps.resize(attr->in.ep_count); ctx->coord.global_count = pmi->get_size(); ctx->coord.global_idx = pmi->get_rank(); @@ -325,7 +323,6 @@ atl_status_t atl_ofi::atl_init(int* argc, ofi_ctx->mnic_type = attr->in.mnic_type; ATL_CALL(atl_ofi_parse_mnic_name(ctx, attr->in.mnic_name), goto err); ofi_ctx->mnic_count = std::min(attr->in.mnic_count, (size_t)(ATL_OFI_MAX_NW_PROV_COUNT)); - ofi_ctx->mnic_count = std::min(ofi_ctx->mnic_count, attr->in.ep_count); ofi_ctx->mnic_count = std::max(ofi_ctx->mnic_count, (size_t)(1)); if ((ofi_ctx->mnic_type != ATL_MNIC_NONE) && @@ -337,6 +334,8 @@ atl_status_t atl_ofi::atl_init(int* argc, if (ofi_ctx->mnic_type == ATL_MNIC_NONE) ofi_ctx->mnic_count = 1; + ofi_ctx->mnic_offset = attr->in.mnic_offset; + attr->out.tag_bits = 64; attr->out.max_tag = 0xFFFFFFFFFFFFFFFF; @@ -397,10 +396,10 @@ atl_status_t atl_ofi::atl_init(int* argc, LOG_INFO("ep_idx: ", ep_idx, ", active_prov_idxs: ", ss.str()); } - ctx->eps[ep_idx] = ep; + eps[ep_idx] = ep; } - pmi->pmrt_barrier(); + ATL_CHECK_STATUS(pmi->pmrt_barrier(), "barrier failed"); max_retry_count_env = getenv(ATL_OFI_MAX_RETRY_COUNT_ENV); if (max_retry_count_env) { @@ -428,10 
+427,11 @@ atl_status_t atl_ofi::atl_init(int* argc, LOG_INFO(" prov_count: ", ofi_ctx->prov_count); LOG_INFO(" nw_prov_count: ", ofi_ctx->nw_prov_count); LOG_INFO(" nw_prov_first_idx: ", ofi_ctx->nw_prov_first_idx); - LOG_INFO(" mnic_type: ", ofi_ctx->mnic_type); + LOG_INFO(" mnic_type: ", to_string(ofi_ctx->mnic_type)); LOG_INFO(" mnic_include_names: ", vec_to_string(ofi_ctx->mnic_include_names)); LOG_INFO(" mnic_exclude_names: ", vec_to_string(ofi_ctx->mnic_exclude_names)); LOG_INFO(" mnic_count: ", ofi_ctx->mnic_count); + LOG_INFO(" mnic_offset: ", to_string(ofi_ctx->mnic_offset)); LOG_INFO(" max_retry_count: ", ofi_ctx->max_retry_count); LOG_INFO(" progress_mode: ", ofi_ctx->progress_mode); #ifdef CCL_ENABLE_OFI_HMEM @@ -468,12 +468,12 @@ atl_status_t atl_ofi::atl_init(int* argc, } if (ctx != nullptr) - atl_finalize(); + finalize(); return ATL_STATUS_FAILURE; } -atl_status_t atl_ofi::atl_finalize() { +atl_status_t atl_ofi::finalize() { is_finalized = true; int ret = 0; size_t idx; @@ -493,13 +493,14 @@ atl_status_t atl_ofi::atl_finalize() { } for (idx = 0; idx < ctx->ep_count; idx++) { - atl_ofi_ep_t* ofi_ep = container_of(ctx->eps[idx], atl_ofi_ep_t, ep); + atl_ofi_ep_t* ofi_ep = container_of(eps[idx], atl_ofi_ep_t, ep); free(ofi_ep); } if (global_data.ctx_count == 0) { if (global_data.dlhandle) { dlclose(global_data.dlhandle); + global_data.dlhandle = nullptr; } if (ctx->coord.global_idx == 0) { @@ -507,20 +508,19 @@ atl_status_t atl_ofi::atl_finalize() { } } - free(ctx->eps); free(ofi_ctx); return RET2ATL(ret); } -atl_status_t atl_ofi::atl_update(std::unique_ptr& pmi) { +atl_status_t atl_ofi::update(std::shared_ptr pmi) { int ret; size_t prov_idx; atl_ofi_ctx_t* ofi_ctx; ofi_ctx = container_of(ctx, atl_ofi_ctx_t, ctx); - pmi->pmrt_barrier(); + ATL_CHECK_STATUS(pmi->pmrt_barrier(), "barrier failed"); atl_ofi_reset(ctx); memset(&(ctx->coord), 0, sizeof(atl_proc_coord_t)); @@ -555,21 +555,21 @@ atl_status_t atl_ofi::atl_update(std::unique_ptr& pmi) { return 
RET2ATL(ret); } - pmi->pmrt_barrier(); + ATL_CHECK_STATUS(pmi->pmrt_barrier(), "barrier failed"); /* normal end of execution */ return RET2ATL(ret); } -atl_ep_t** atl_ofi::atl_get_eps() { - return ctx->eps; +std::vector atl_ofi::get_eps() { + return eps; } -atl_proc_coord_t* atl_ofi::atl_get_proc_coord() { +atl_proc_coord_t* atl_ofi::get_proc_coord() { return &(ctx->coord); } -atl_status_t atl_ofi::atl_mr_reg(const void* buf, size_t len, atl_mr_t** mr) { +atl_status_t atl_ofi::mr_reg(const void* buf, size_t len, atl_mr_t** mr) { int ret; atl_ofi_ctx_t* ofi_ctx; ofi_ctx = container_of(ctx, atl_ofi_ctx_t, ctx); @@ -605,7 +605,7 @@ atl_status_t atl_ofi::atl_mr_reg(const void* buf, size_t len, atl_mr_t** mr) { return ATL_STATUS_FAILURE; } -atl_status_t atl_ofi::atl_mr_dereg(atl_mr_t* mr) { +atl_status_t atl_ofi::mr_dereg(atl_mr_t* mr) { atl_ofi_mr_t* ofi_mr; ofi_mr = container_of(mr, atl_ofi_mr_t, mr); int ret = fi_close(&ofi_mr->fi_mr->fid); @@ -613,12 +613,12 @@ atl_status_t atl_ofi::atl_mr_dereg(atl_mr_t* mr) { return RET2ATL(ret); } -atl_status_t atl_ofi::atl_ep_send(atl_ep_t* ep, - const void* buf, - size_t len, - int dst_proc_idx, - uint64_t tag, - atl_req_t* req) { +atl_status_t atl_ofi::send(atl_ep_t* ep, + const void* buf, + size_t len, + int dst_proc_idx, + uint64_t tag, + atl_req_t* req) { ssize_t ret; atl_ofi_prov_t* prov; @@ -627,14 +627,10 @@ atl_status_t atl_ofi::atl_ep_send(atl_ep_t* ep, prov = atl_ofi_get_prov(ep, dst_proc_idx, len); prov_ep = &(prov->eps[ep->idx]); - ofi_req = ((atl_ofi_req_t*)req->internal); - req->tag = tag; - req->remote_proc_idx = dst_proc_idx; - ofi_req->comp_state = ATL_OFI_COMP_POSTED; + atl_ofi_init_req(req, prov_ep, prov_ep->tx); - ofi_req->prov_ep = prov_ep; - ofi_req->fi_ep = prov_ep->tx; + ofi_req = ((atl_ofi_req_t*)req->internal); cache.get(ep->idx, prov->domain, const_cast(buf), len, &ofi_req->mr); void* desc = (ofi_req->mr) ? 
fi_mr_desc(ofi_req->mr) : nullptr; @@ -658,12 +654,12 @@ atl_status_t atl_ofi::atl_ep_send(atl_ep_t* ep, return RET2ATL(ret); } -atl_status_t atl_ofi::atl_ep_recv(atl_ep_t* ep, - void* buf, - size_t len, - int src_proc_idx, - uint64_t tag, - atl_req_t* req) { +atl_status_t atl_ofi::recv(atl_ep_t* ep, + void* buf, + size_t len, + int src_proc_idx, + uint64_t tag, + atl_req_t* req) { ssize_t ret; atl_ofi_prov_t* prov; @@ -672,14 +668,10 @@ atl_status_t atl_ofi::atl_ep_recv(atl_ep_t* ep, prov = atl_ofi_get_prov(ep, src_proc_idx, len); prov_ep = &(prov->eps[ep->idx]); - ofi_req = ((atl_ofi_req_t*)req->internal); - req->tag = tag; - req->remote_proc_idx = src_proc_idx; - ofi_req->comp_state = ATL_OFI_COMP_POSTED; + atl_ofi_init_req(req, prov_ep, prov_ep->rx); - ofi_req->prov_ep = prov_ep; - ofi_req->fi_ep = prov_ep->rx; + ofi_req = ((atl_ofi_req_t*)req->internal); cache.get(ep->idx, prov->domain, const_cast(buf), len, &ofi_req->mr); void* desc = (ofi_req->mr) ? fi_mr_desc(ofi_req->mr) : nullptr; @@ -703,11 +695,11 @@ atl_status_t atl_ofi::atl_ep_recv(atl_ep_t* ep, return RET2ATL(ret); } -atl_status_t atl_ofi::atl_ep_probe(atl_ep_t* ep, - int src_proc_idx, - uint64_t tag, - int* found, - size_t* recv_len) { +atl_status_t atl_ofi::probe(atl_ep_t* ep, + int src_proc_idx, + uint64_t tag, + int* found, + size_t* recv_len) { CCL_THROW("unexpected path"); atl_status_t ret; @@ -763,7 +755,7 @@ atl_status_t atl_ofi::atl_ep_probe(atl_ep_t* ep, } do { - ret = atl_ep_poll(ep); + ret = poll(ep); if (ret != ATL_STATUS_SUCCESS) return ret; @@ -814,82 +806,14 @@ atl_status_t atl_ofi::atl_ep_probe(atl_ep_t* ep, return RET2ATL(ofi_ret); } -atl_status_t atl_ofi::atl_ep_allgatherv(atl_ep_t* ep, - const void* send_buf, - size_t send_len, - void* recv_buf, - const int* recv_lens, - const int* offsets, - atl_req_t* req) { - return ATL_STATUS_UNSUPPORTED; -} - -atl_status_t atl_ofi::atl_ep_allreduce(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t len, - atl_datatype_t dtype, - 
atl_reduction_t op, - atl_req_t* req) { - return ATL_STATUS_UNSUPPORTED; -} - -atl_status_t atl_ofi::atl_ep_alltoall(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - int len, - atl_req_t* req) { - return ATL_STATUS_UNSUPPORTED; -} - -atl_status_t atl_ofi::atl_ep_alltoallv(atl_ep_t* ep, - const void* send_buf, - const int* send_lens, - const int* send_offsets, - void* recv_buf, - const int* recv_lens, - const int* recv_offsets, - atl_req_t* req) { - return ATL_STATUS_UNSUPPORTED; -} - -atl_status_t atl_ofi::atl_ep_barrier(atl_ep_t* ep, atl_req_t* req) { - return ATL_STATUS_UNSUPPORTED; -} - -atl_status_t atl_ofi::atl_ep_bcast(atl_ep_t* ep, void* buf, size_t len, int root, atl_req_t* req) { - return ATL_STATUS_UNSUPPORTED; -} - -atl_status_t atl_ofi::atl_ep_reduce(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t len, - int root, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) { - return ATL_STATUS_UNSUPPORTED; -} - -atl_status_t atl_ofi::atl_ep_reduce_scatter(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t recv_len, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) { - return ATL_STATUS_UNSUPPORTED; -} - -atl_status_t atl_ofi::atl_ep_read(atl_ep_t* ep, - void* buf, - size_t len, - atl_mr_t* mr, - uint64_t addr, - uintptr_t remote_key, - int dst_proc_idx, - atl_req_t* req) { +atl_status_t atl_ofi::read(atl_ep_t* ep, + void* buf, + size_t len, + atl_mr_t* mr, + uint64_t addr, + uintptr_t remote_key, + int dst_proc_idx, + atl_req_t* req) { ssize_t ret; atl_ofi_prov_t* prov; @@ -898,14 +822,10 @@ atl_status_t atl_ofi::atl_ep_read(atl_ep_t* ep, prov = atl_ofi_get_prov(ep, dst_proc_idx, len); prov_ep = &(prov->eps[ep->idx]); - ofi_req = ((atl_ofi_req_t*)req->internal); - req->tag = 0; - req->remote_proc_idx = dst_proc_idx; - ofi_req->comp_state = ATL_OFI_COMP_POSTED; + atl_ofi_init_req(req, prov_ep, prov_ep->tx); - ofi_req->prov_ep = prov_ep; - ofi_req->fi_ep = prov_ep->tx; + ofi_req = 
((atl_ofi_req_t*)req->internal); ATL_OFI_RETRY(fi_read(prov_ep->tx, buf, @@ -920,14 +840,14 @@ atl_status_t atl_ofi::atl_ep_read(atl_ep_t* ep, return RET2ATL(ret); } -atl_status_t atl_ofi::atl_ep_write(atl_ep_t* ep, - const void* buf, - size_t len, - atl_mr_t* mr, - uint64_t addr, - uintptr_t remote_key, - int dst_proc_idx, - atl_req_t* req) { +atl_status_t atl_ofi::write(atl_ep_t* ep, + const void* buf, + size_t len, + atl_mr_t* mr, + uint64_t addr, + uintptr_t remote_key, + int dst_proc_idx, + atl_req_t* req) { ssize_t ret; atl_ofi_prov_t* prov; @@ -936,14 +856,10 @@ atl_status_t atl_ofi::atl_ep_write(atl_ep_t* ep, prov = atl_ofi_get_prov(ep, dst_proc_idx, len); prov_ep = &(prov->eps[ep->idx]); - ofi_req = ((atl_ofi_req_t*)req->internal); - req->tag = 0; - req->remote_proc_idx = dst_proc_idx; - ofi_req->comp_state = ATL_OFI_COMP_POSTED; + atl_ofi_init_req(req, prov_ep, prov_ep->tx); - ofi_req->prov_ep = prov_ep; - ofi_req->fi_ep = prov_ep->tx; + ofi_req = ((atl_ofi_req_t*)req->internal); ATL_OFI_RETRY(fi_write(prov_ep->tx, buf, @@ -958,7 +874,7 @@ atl_status_t atl_ofi::atl_ep_write(atl_ep_t* ep, return RET2ATL(ret); } -atl_status_t atl_ofi::atl_ep_wait(atl_ep_t* ep, atl_req_t* req) { +atl_status_t atl_ofi::wait(atl_ep_t* ep, atl_req_t* req) { atl_status_t ret; atl_ofi_req_t* ofi_req; @@ -966,18 +882,18 @@ atl_status_t atl_ofi::atl_ep_wait(atl_ep_t* ep, atl_req_t* req) { ofi_req = ((atl_ofi_req_t*)req->internal); while ((ofi_req->comp_state != ATL_OFI_COMP_COMPLETED) && - ((ret = atl_ep_poll(ep)) == ATL_STATUS_SUCCESS)) + ((ret = poll(ep)) == ATL_STATUS_SUCCESS)) ; return ret; } -atl_status_t atl_ofi::atl_ep_wait_all(atl_ep_t* ep, atl_req_t* reqs, size_t count) { +atl_status_t atl_ofi::wait_all(atl_ep_t* ep, atl_req_t* reqs, size_t count) { size_t i; atl_status_t ret; for (i = 0; i < count; i++) { - ret = atl_ep_wait(ep, &reqs[i]); + ret = wait(ep, &reqs[i]); if (ret != ATL_STATUS_SUCCESS) return ret; } @@ -985,7 +901,7 @@ atl_status_t 
atl_ofi::atl_ep_wait_all(atl_ep_t* ep, atl_req_t* reqs, size_t coun return ATL_STATUS_SUCCESS; } -atl_status_t atl_ofi::atl_ep_cancel(atl_ep_t* ep, atl_req_t* req) { +atl_status_t atl_ofi::cancel(atl_ep_t* ep, atl_req_t* req) { int ret; atl_ofi_req_t* ofi_req; @@ -1000,7 +916,7 @@ atl_status_t atl_ofi::atl_ep_cancel(atl_ep_t* ep, atl_req_t* req) { return ATL_STATUS_SUCCESS; } -atl_status_t atl_ofi::atl_ep_poll(atl_ep_t* ep) { +atl_status_t atl_ofi::poll(atl_ep_t* ep) { atl_ofi_ctx_t* ofi_ctx = container_of(ep->ctx, atl_ofi_ctx_t, ctx); if (ofi_ctx->progress_mode == ATL_PROGRESS_POLL) { atl_ep_progress(ep); @@ -1008,9 +924,7 @@ atl_status_t atl_ofi::atl_ep_poll(atl_ep_t* ep) { return ATL_STATUS_SUCCESS; } -atl_status_t atl_ofi::atl_ep_check(atl_ep_t* ep, int* is_completed, atl_req_t* req) { - CCL_THROW_IF_NOT(is_completed); - +atl_status_t atl_ofi::check(atl_ep_t* ep, atl_req_t* req) { atl_status_t status; atl_ofi_req_t* ofi_req; atl_ofi_ctx_t* ofi_ctx = container_of(ep->ctx, atl_ofi_ctx_t, ctx); @@ -1018,14 +932,16 @@ atl_status_t atl_ofi::atl_ep_check(atl_ep_t* ep, int* is_completed, atl_req_t* r status = ATL_STATUS_SUCCESS; ofi_req = ((atl_ofi_req_t*)req->internal); - *is_completed = (ofi_req->comp_state == ATL_OFI_COMP_COMPLETED); - if (*is_completed) { + CCL_THROW_IF_NOT(!req->is_completed, "request is already completed"); + + req->is_completed = (ofi_req->comp_state == ATL_OFI_COMP_COMPLETED); + if (req->is_completed) { return ATL_STATUS_SUCCESS; } if (ofi_ctx->progress_mode == ATL_PROGRESS_CHECK) { status = atl_ep_progress(ep); - *is_completed = (ofi_req->comp_state == ATL_OFI_COMP_COMPLETED); + req->is_completed = (ofi_req->comp_state == ATL_OFI_COMP_COMPLETED); } return status; @@ -1033,7 +949,7 @@ atl_status_t atl_ofi::atl_ep_check(atl_ep_t* ep, int* is_completed, atl_req_t* r atl_ofi::~atl_ofi() { if (!is_finalized) { - atl_finalize(); + finalize(); } } diff --git a/src/atl/ofi/atl_ofi.hpp b/src/atl/ofi/atl_ofi.hpp index a06a45648..4d7b44079 100644 --- 
a/src/atl/ofi/atl_ofi.hpp +++ b/src/atl/ofi/atl_ofi.hpp @@ -13,146 +13,86 @@ See the License for the specific language governing permissions and limitations under the License. */ +#pragma once #include #include #include #include -#include "atl.h" #include "atl_ofi_helper.hpp" #include "common/utils/hash.hpp" -class atl_ofi final : public iatl { +class atl_ofi { public: atl_ofi() = default; - ~atl_ofi() override; + ~atl_ofi(); static atl_status_t atl_set_env(const atl_attr_t& attr); - atl_status_t atl_init(int* argc, - char*** argv, - atl_attr_t* attr, - const char* main_addr, - std::unique_ptr& pmi) override; - - atl_status_t atl_update(std::unique_ptr& pmi) override; - - atl_ep_t** atl_get_eps() override; - - atl_proc_coord_t* atl_get_proc_coord() override; - - atl_status_t atl_mr_reg(const void* buf, size_t len, atl_mr_t** mr) override; - - atl_status_t atl_mr_dereg(atl_mr_t* mr) override; - - atl_status_t atl_ep_send(atl_ep_t* ep, - const void* buf, - size_t len, - int dst_proc_idx, - uint64_t tag, - atl_req_t* req) override; - - atl_status_t atl_ep_recv(atl_ep_t* ep, - void* buf, - size_t len, - int src_proc_idx, - uint64_t tag, - atl_req_t* req) override; - - atl_status_t atl_ep_probe(atl_ep_t* ep, - int src_proc_idx, - uint64_t tag, - int* found, - size_t* recv_len) override; - - atl_status_t atl_ep_allgatherv(atl_ep_t* ep, - const void* send_buf, - size_t send_len, - void* recv_buf, - const int* recv_lens, - const int* offsets, - atl_req_t* req) override; - - atl_status_t atl_ep_allreduce(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t len, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) override; - - atl_status_t atl_ep_alltoall(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - int len, - atl_req_t* req) override; - - atl_status_t atl_ep_alltoallv(atl_ep_t* ep, - const void* send_buf, - const int* send_lens, - const int* send_offsets, - void* recv_buf, - const int* recv_lens, - const int* recv_offsets, - atl_req_t* 
req) override; - - atl_status_t atl_ep_barrier(atl_ep_t* ep, atl_req_t* req) override; - - atl_status_t atl_ep_bcast(atl_ep_t* ep, - void* buf, - size_t len, - int root, - atl_req_t* req) override; - - atl_status_t atl_ep_reduce(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t len, - int root, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) override; - - atl_status_t atl_ep_reduce_scatter(atl_ep_t* ep, - const void* send_buf, - void* recv_buf, - size_t recv_len, - atl_datatype_t dtype, - atl_reduction_t op, - atl_req_t* req) override; - - atl_status_t atl_ep_read(atl_ep_t* ep, - void* buf, - size_t len, - atl_mr_t* mr, - uint64_t addr, - uintptr_t remote_key, - int dst_proc_idx, - atl_req_t* req) override; - - atl_status_t atl_ep_write(atl_ep_t* ep, - const void* buf, - size_t len, - atl_mr_t* mr, - uint64_t addr, - uintptr_t remote_key, - int dst_proc_idx, - atl_req_t* req) override; - - atl_status_t atl_ep_wait(atl_ep_t* ep, atl_req_t* req) override; - - atl_status_t atl_ep_wait_all(atl_ep_t* ep, atl_req_t* req, size_t count) override; - - atl_status_t atl_ep_cancel(atl_ep_t* ep, atl_req_t* req) override; - - atl_status_t atl_ep_poll(atl_ep_t* ep) override; - - atl_status_t atl_ep_check(atl_ep_t* ep, int* is_completed, atl_req_t* req) override; - - atl_status_t atl_finalize() override; - - bool is_inited() override { + atl_status_t init(int* argc, + char*** argv, + atl_attr_t* attr, + const char* main_addr, + std::shared_ptr pmi); + + atl_status_t update(std::shared_ptr pmi); + + std::vector get_eps(); + + atl_proc_coord_t* get_proc_coord(); + + atl_status_t mr_reg(const void* buf, size_t len, atl_mr_t** mr); + + atl_status_t mr_dereg(atl_mr_t* mr); + + atl_status_t send(atl_ep_t* ep, + const void* buf, + size_t len, + int dst_proc_idx, + uint64_t tag, + atl_req_t* req); + + atl_status_t recv(atl_ep_t* ep, + void* buf, + size_t len, + int src_proc_idx, + uint64_t tag, + atl_req_t* req); + + atl_status_t probe(atl_ep_t* ep, int 
src_proc_idx, uint64_t tag, int* found, size_t* recv_len); + + atl_status_t read(atl_ep_t* ep, + void* buf, + size_t len, + atl_mr_t* mr, + uint64_t addr, + uintptr_t remote_key, + int dst_proc_idx, + atl_req_t* req); + + atl_status_t write(atl_ep_t* ep, + const void* buf, + size_t len, + atl_mr_t* mr, + uint64_t addr, + uintptr_t remote_key, + int dst_proc_idx, + atl_req_t* req); + + atl_status_t wait(atl_ep_t* ep, atl_req_t* req); + + atl_status_t wait_all(atl_ep_t* ep, atl_req_t* req, size_t count); + + atl_status_t cancel(atl_ep_t* ep, atl_req_t* req); + + atl_status_t poll(atl_ep_t* ep); + + atl_status_t check(atl_ep_t* ep, atl_req_t* req); + + atl_status_t finalize(); + + bool is_inited() { return inited; } @@ -162,6 +102,7 @@ class atl_ofi final : public iatl { atl_status_t atl_prov_ep_handle_cq_err(atl_ofi_prov_ep_t* ep); atl_ctx_t* ctx = nullptr; + std::vector eps; class mr_cache { public: diff --git a/src/atl/ofi/atl_ofi_comm.cpp b/src/atl/ofi/atl_ofi_comm.cpp new file mode 100644 index 000000000..63a86f063 --- /dev/null +++ b/src/atl/ofi/atl_ofi_comm.cpp @@ -0,0 +1,226 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ +#include "atl/ofi/atl_ofi_comm.hpp" +#include "atl/util/pm/pmi_resizable_rt/pmi_resizable_simple.h" +#include "atl/util/pm/pmi_rt/pmi_simple.h" +#include "atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.h" +#include "atl/util/pm/pmi_resizable_rt/pmi_resizable.h" +#include "atl/util/pm/pmi_resizable_rt/pmi_resizable_simple_internal.h" +#include "atl/ofi/atl_ofi.hpp" +#include "exec/exec.hpp" + +std::atomic atl_ofi_comm::comm_count{ 0 }; +atl_ofi* atl_ofi_comm::transport{ nullptr }; + +atl_ofi_comm::~atl_ofi_comm() { + static std::mutex memory_mutex; + std::lock_guard lock(memory_mutex); + tag.reset(); + comm_count--; + if (comm_count.load() == 0) { + delete transport; + transport = nullptr; + } +} + +atl_ofi_comm::atl_ofi_comm() { + char* pm_type_str = getenv(PM_TYPE); + + if (pm_type_str) { + if (strstr(pm_type_str, PM_RT_VAL_SIMPLE)) { + pmi = std::shared_ptr(new pmi_simple()); + } + else if (strstr(pm_type_str, PM_RT_VAL_RESIZABLE)) { + std::shared_ptr k(new internal_kvs()); + pmi = std::shared_ptr(new pmi_resizable(k)); + } + else { + LOG_ERROR("Unknown %s: %s\n", PM_TYPE, pm_type_str); + } + } + else { + pmi = std::shared_ptr(new pmi_simple()); + } + + CCL_THROW_IF_NOT(init_transport(true) == ATL_STATUS_SUCCESS, "init transport failed"); +} + +atl_ofi_comm::atl_ofi_comm(std::shared_ptr k) { + char* pm_type_str = getenv(PM_TYPE); + + if (pm_type_str) { + if (strstr(pm_type_str, PM_RT_VAL_SIMPLE)) { + pmi = std::shared_ptr(new pmi_simple()); + } + else if (strstr(pm_type_str, PM_RT_VAL_RESIZABLE)) { + pmi = std::shared_ptr(new pmi_resizable(k)); + } + else { + LOG_ERROR("Unknown %s: %s\n", PM_TYPE, pm_type_str); + } + } + else { + pmi = std::shared_ptr(new pmi_simple()); + } + + CCL_THROW_IF_NOT(init_transport(true) == ATL_STATUS_SUCCESS, "init transport failed"); +} + +atl_ofi_comm::atl_ofi_comm(int total_rank_count, + const std::vector& ranks, + std::shared_ptr k) { + std::shared_ptr kvs; + if ((kvs = std::dynamic_pointer_cast(k)) != nullptr) { 
+ pmi = + std::shared_ptr(new pmi_resizable_simple_internal(total_rank_count, ranks, kvs)); + } + else { + pmi = std::shared_ptr(new pmi_resizable_simple(total_rank_count, ranks, k)); + } + + CCL_THROW_IF_NOT(init_transport(true) == ATL_STATUS_SUCCESS, "init transport failed"); +} +atl_status_t atl_ofi_comm::init_transport(bool is_new) { + LOG_DEBUG("init ATL, requested ep_count ", attr.in.ep_count); + + if (is_new) { + ATL_CHECK_STATUS(pmi->pmrt_init(), "pmi init failed"); + static std::mutex memory_mutex; + { + std::lock_guard lock(memory_mutex); + if (!transport) { + transport = new atl_ofi(); + } + if (!transport->is_inited()) { + CCL_THROW_IF_NOT( + transport->init(nullptr, nullptr, &attr, nullptr, pmi) == ATL_STATUS_SUCCESS, + "failed to initialize ATL"); + + if (pmi->get_rank() == 0) { + print_atl_attrs(); + } + } + } + eps = transport->get_eps(); + coord = *transport->get_proc_coord(); + + parent_rank = rank = pmi->get_rank(); + parent_size = size = pmi->get_size(); + + rank2rank_map.resize(size); + for (int i = 0; i < size; i++) { + rank2rank_map[i] = i; + } + } + + threads_per_process = pmi->get_threads_per_process(); + ranks_per_process = pmi->get_ranks_per_process(); + comm_count++; + + init_tag(); + + if (pmi->get_local_thread_idx() == 0) { + executor_update(); + } + + return ATL_STATUS_SUCCESS; +} + +std::shared_ptr atl_ofi_comm::comm_split(int color) { + return std::shared_ptr(new atl_ofi_comm(this, color)); +} + +atl_ofi_comm::atl_ofi_comm(atl_ofi_comm* parent, int color) { + eps = parent->eps; + parent_size = parent->parent_size; + parent_rank = parent->parent_rank; + pmi = parent->pmi; + + coord.hostname_hash = transport->get_proc_coord()->hostname_hash; + coord.local_count = 0; + coord.local_idx = 0; + + std::vector ranks_info(parent_size); + rank_info_t rank_info{ color, parent_rank, coord.hostname_hash }; + parent->rank_info_exchange(ranks_info, rank_info); + + size = 0; + for (auto& it : ranks_info) { + int recv_color; + int recv_rank; + 
size_t recv_hash; + std::tie(recv_color, recv_rank, recv_hash) = it; + if (recv_color == color) { + rank2rank_map.push_back(recv_rank); + + if (recv_hash == coord.hostname_hash) { + coord.local_count++; + } + if (recv_rank == parent_rank) { + coord.global_idx = rank = size; + coord.local_idx = coord.local_count; + } + size++; + } + } + coord.global_count = size; + CCL_THROW_IF_NOT(init_transport(false) == ATL_STATUS_SUCCESS, "init transport failed"); +} + +void atl_ofi_comm::rank_info_exchange(std::vector& ranks_info, rank_info_t rank_info) { + std::vector send_reqs(size - 1); + std::vector recv_reqs(size - 1); + const size_t ep_idx = 0; + + for (int i = 0, j = 0; i < size; i++) { + if (i == rank) + continue; + atl_status_t ret; + do { + ret = + send(ep_idx, &rank_info, sizeof(rank_info_t), i, rank * 1000 + i, &(send_reqs[j])); + CCL_THROW_IF_NOT(ret != ATL_STATUS_FAILURE, "send failed"); + ccl_yield(ccl::global_data::env().yield_type); + } while (ret == ATL_STATUS_AGAIN); + + do { + ret = recv( + ep_idx, &ranks_info[i], sizeof(rank_info_t), i, i * 1000 + rank, &(recv_reqs[j])); + CCL_THROW_IF_NOT(ret != ATL_STATUS_FAILURE, "recv failed"); + ccl_yield(ccl::global_data::env().yield_type); + } while (ret == ATL_STATUS_AGAIN); + j++; + } + + ranks_info[rank] = rank_info; + bool is_completed = false; + while (!is_completed) { + is_completed = true; + poll(ep_idx); + for (size_t i = 0; i < send_reqs.size(); i++) { + if (!send_reqs[i].is_completed) { + CCL_THROW_IF_NOT(check(ep_idx, &(send_reqs[i])) != ATL_STATUS_FAILURE, + "check send failed"); + is_completed = false; + } + if (!recv_reqs[i].is_completed) { + CCL_THROW_IF_NOT(check(ep_idx, &(recv_reqs[i])) != ATL_STATUS_FAILURE, + "check recv failed"); + is_completed = false; + } + } + } +} diff --git a/src/atl/ofi/atl_ofi_comm.hpp b/src/atl/ofi/atl_ofi_comm.hpp new file mode 100644 index 000000000..068a43584 --- /dev/null +++ b/src/atl/ofi/atl_ofi_comm.hpp @@ -0,0 +1,245 @@ +/* + Copyright 2016-2020 Intel Corporation + 
+ Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#include "atl/atl_base_comm.hpp" +#include "atl/ofi/atl_ofi.hpp" + +class atl_ofi_comm : public atl_base_comm { +public: + ~atl_ofi_comm() override; + atl_ofi_comm(); + atl_ofi_comm(std::shared_ptr k); + atl_ofi_comm(int total_rank_count, + const std::vector& ranks, + std::shared_ptr k); + + atl_status_t main_addr_reserve(char* main_addr) override { + return pmi->pmrt_main_addr_reserve(main_addr); + } + + atl_status_t finalize() override { + ATL_CHECK_STATUS(pmi->pmrt_finalize(), "failed to finalize PMI"); + + return transport->finalize(); + } + + atl_status_t update() override { + return transport->update(pmi); + } + + atl_status_t wait_notification() override { + return pmi->pmrt_wait_notification(); + } + + atl_status_t set_resize_function(atl_resize_fn_t fn) override { + return pmi->pmrt_set_resize_function(fn); + } + + atl_status_t mr_reg(const void* buf, size_t len, atl_mr_t** mr) override { + return transport->mr_reg(buf, len, mr); + } + + atl_status_t mr_dereg(atl_mr_t* mr) override { + return transport->mr_dereg(mr); + } + + atl_status_t send(size_t ep_idx, + const void* buf, + size_t len, + int dst_proc_idx, + uint64_t tag, + atl_req_t* req) override { + return transport->send(eps[ep_idx], buf, len, rank2rank_map[dst_proc_idx], tag, req); + } + + atl_status_t recv(size_t ep_idx, + void* buf, + size_t len, + int src_proc_idx, + uint64_t tag, + atl_req_t* req) override { + return 
transport->recv(eps[ep_idx], buf, len, rank2rank_map[src_proc_idx], tag, req); + } + + atl_status_t probe(size_t ep_idx, + int src_proc_idx, + uint64_t tag, + int* found, + size_t* recv_len) override { + return transport->probe(eps[ep_idx], rank2rank_map[src_proc_idx], tag, found, recv_len); + } + + atl_status_t allgatherv(size_t ep_idx, + const void* send_buf, + size_t send_len, + void* recv_buf, + const int* recv_lens, + const int* offsets, + atl_req_t* req) override { + return ATL_STATUS_UNSUPPORTED; + } + + atl_status_t allreduce(size_t ep_idx, + const void* send_buf, + void* recv_buf, + size_t len, + atl_datatype_t dtype, + atl_reduction_t op, + atl_req_t* req) override { + return ATL_STATUS_UNSUPPORTED; + } + + atl_status_t alltoall(size_t ep_idx, + const void* send_buf, + void* recv_buf, + int len, + atl_req_t* req) override { + return ATL_STATUS_UNSUPPORTED; + } + + atl_status_t alltoallv(size_t ep_idx, + const void* send_buf, + const int* send_lens, + const int* send_offsets, + void* recv_buf, + const int* recv_lens, + const int* recv_offsets, + atl_req_t* req) override { + return ATL_STATUS_UNSUPPORTED; + } + + atl_status_t barrier(size_t ep_idx, atl_req_t* req) override { + return ATL_STATUS_UNSUPPORTED; + } + + atl_status_t bcast(size_t ep_idx, void* buf, size_t len, int root, atl_req_t* req) override { + return ATL_STATUS_UNSUPPORTED; + } + + atl_status_t reduce(size_t ep_idx, + const void* send_buf, + void* recv_buf, + size_t len, + int root, + atl_datatype_t dtype, + atl_reduction_t op, + atl_req_t* req) override { + return ATL_STATUS_UNSUPPORTED; + } + + atl_status_t reduce_scatter(size_t ep_idx, + const void* send_buf, + void* recv_buf, + size_t recv_len, + atl_datatype_t dtype, + atl_reduction_t op, + atl_req_t* req) override { + return ATL_STATUS_UNSUPPORTED; + } + + atl_status_t read(size_t ep_idx, + void* buf, + size_t len, + atl_mr_t* mr, + uint64_t addr, + uintptr_t remote_key, + int dst_proc_idx, + atl_req_t* req) override { + return 
transport->read( + eps[ep_idx], buf, len, mr, addr, remote_key, rank2rank_map[dst_proc_idx], req); + } + + atl_status_t write(size_t ep_idx, + const void* buf, + size_t len, + atl_mr_t* mr, + uint64_t addr, + uintptr_t remote_key, + int dst_proc_idx, + atl_req_t* req) override { + return transport->write( + eps[ep_idx], buf, len, mr, addr, remote_key, rank2rank_map[dst_proc_idx], req); + } + + atl_status_t wait(size_t ep_idx, atl_req_t* req) override { + return transport->wait(eps[ep_idx], req); + } + + atl_status_t wait_all(size_t ep_idx, atl_req_t* req, size_t count) override { + return transport->wait_all(eps[ep_idx], req, count); + } + + atl_status_t cancel(size_t ep_idx, atl_req_t* req) override { + return transport->cancel(eps[ep_idx], req); + } + + atl_status_t poll(size_t ep_idx) override { + return transport->poll(eps[ep_idx]); + } + + atl_status_t check(size_t ep_idx, atl_req_t* req) override { + return transport->check(eps[ep_idx], req); + } + + size_t get_threads_per_process() override { + return threads_per_process; + } + + size_t get_ranks_per_process() override { + return ranks_per_process; + } + + int get_rank() override { + return rank; + } + + int get_size() override { + return size; + } + + int get_r2r_color() override { + return coord.local_idx; + } + + int get_host_color() override { + return coord.hostname_hash; + } + + /* + * TODO: Temporary change. 
+ * Need to define correct to unique id + */ + size_t get_id() override { + return 0; + } + + std::shared_ptr comm_split(int color) override; + + std::vector get_rank2rank_map() override { + return rank2rank_map; + } + +private: + static atl_ofi* transport; + std::vector eps; + static std::atomic comm_count; + + atl_ofi_comm(atl_ofi_comm* parent, int color); + atl_status_t init_transport(bool is_new); + using rank_info_t = std::tuple; + void rank_info_exchange(std::vector& ranks_info, rank_info_t rank_info); +}; diff --git a/src/atl/ofi/atl_ofi_helper.cpp b/src/atl/ofi/atl_ofi_helper.cpp index 7218da8d1..8caa2fd53 100644 --- a/src/atl/ofi/atl_ofi_helper.cpp +++ b/src/atl/ofi/atl_ofi_helper.cpp @@ -84,6 +84,50 @@ std::string atl_ofi_get_nic_name(const struct fi_info* prov) { return ss.str(); } +const char* atl_ofi_link_state_str(enum fi_link_state state) { + switch (state) { + case FI_LINK_DOWN: return "down"; + case FI_LINK_UP: return "up"; + default: return "unknown"; + } +} + +std::string atl_ofi_get_nic_info(const struct fi_info* prov) { + std::stringstream ss; + + ss << "{ "; + + ss << "name " << atl_ofi_get_nic_name(prov); + + if (prov->nic && prov->nic->link_attr) { + ss << ", state " << atl_ofi_link_state_str(prov->nic->link_attr->state); + + if (prov->nic->link_attr->mtu) { + ss << ", mtu " << prov->nic->link_attr->mtu << " bytes"; + } + + if (prov->nic->link_attr->speed) { + const float bits_to_gbytes_coef = 8.0 * 1000 * 1000 * 1000; + ss << ", speed " << (float)prov->nic->link_attr->speed / bits_to_gbytes_coef << " GB/s"; + } + + if (prov->nic->link_attr->address) { + ss << ", address " << prov->nic->link_attr->address; + } + + if (prov->nic->link_attr->network_type) { + ss << ", network_type " << prov->nic->link_attr->network_type; + } + } + else { + ss << ", no link attr"; + } + + ss << " }"; + + return ss.str(); +} + atl_ofi_prov_t* atl_ofi_get_prov(atl_ep_t* ep, int peer_proc_idx, size_t msg_size) { size_t prov_idx; atl_ofi_ctx_t* ofi_ctx = 
container_of(ep->ctx, atl_ofi_ctx_t, ctx); @@ -107,9 +151,11 @@ atl_ofi_prov_t* atl_ofi_get_prov(atl_ep_t* ep, int peer_proc_idx, size_t msg_siz prov_idx = ofi_ctx->nw_prov_first_idx + nw_prov_offset; } - LOG_DEBUG("get_prov: ep_idx ", + LOG_DEBUG("select nic: ep_idx ", ep->idx, - ", prov_idx ", + ", local_proc_idx ", + coord->local_idx, + ", nic_idx ", prov_idx, ", my_node_idx ", my_node_idx, @@ -136,7 +182,7 @@ fi_addr_t atl_ofi_get_addr(atl_ctx_t* ctx, atl_ofi_prov_t* prov, int proc_idx, s return *(prov->addr_table + ((ctx->ep_count * (proc_idx - prov->first_proc_idx)) + ep_idx)); } -atl_status_t atl_ofi_get_local_proc_coord(atl_ofi_ctx_t* ofi_ctx, std::unique_ptr& pmi) { +atl_status_t atl_ofi_get_local_proc_coord(atl_ofi_ctx_t* ofi_ctx, std::shared_ptr pmi) { CCL_THROW_IF_NOT(ofi_ctx, "ofi_ctx is null"); atl_proc_coord_t* coord = &(ofi_ctx->ctx.coord); @@ -178,7 +224,7 @@ atl_status_t atl_ofi_get_local_proc_coord(atl_ofi_ctx_t* ofi_ctx, std::unique_pt goto fn_err; } - pmi->pmrt_barrier(); + ATL_CHECK_STATUS(pmi->pmrt_barrier(), "barrier failed"); all_hostnames = (char*)calloc(1, coord->global_count * ATL_MAX_HOSTNAME_LEN); if (!all_hostnames) { @@ -225,7 +271,7 @@ atl_status_t atl_ofi_get_local_proc_coord(atl_ofi_ctx_t* ofi_ctx, std::unique_pt atl_status_t atl_ofi_prov_update_addr_table(atl_ofi_ctx_t* ofi_ctx, size_t prov_idx, - std::unique_ptr& pmi) { + std::shared_ptr pmi) { CCL_THROW_IF_NOT(ofi_ctx, "ofi_ctx is null"); atl_ctx_t* ctx = &(ofi_ctx->ctx); @@ -286,7 +332,7 @@ atl_status_t atl_ofi_prov_update_addr_table(atl_ofi_ctx_t* ofi_ctx, return ATL_STATUS_FAILURE; } - pmi->pmrt_barrier(); + ATL_CHECK_STATUS(pmi->pmrt_barrier(), "barrier failed"); /* retrieve all OFI EP names in order */ for (i = 0; i < ctx->coord.global_count; i++) { @@ -445,7 +491,7 @@ atl_status_t atl_ofi_prov_ep_get_name(atl_ofi_prov_t* prov, size_t ep_idx) { atl_status_t atl_ofi_prov_eps_connect(atl_ofi_ctx_t* ofi_ctx, size_t prov_idx, - std::unique_ptr& pmi) { + std::shared_ptr pmi) { 
int ret; size_t ep_idx; @@ -811,6 +857,8 @@ atl_status_t atl_ofi_set_env(const atl_attr_t& attr) { setenv("FI_SHM_DISABLE_CMA", "1", 0); #endif // CCL_ENABLE_SYCL + setenv("FI_MLX_MULTI_EP", "1", 0); + atl_ofi_adjust_env(attr); /* @@ -889,7 +937,7 @@ atl_status_t atl_ofi_prov_init(atl_ctx_t* ctx, struct fi_info* info, atl_ofi_prov_t* prov, atl_attr_t* attr, - std::unique_ptr& pmi) { + std::shared_ptr pmi) { struct fi_av_attr av_attr; size_t ep_idx = 0; ssize_t ret = 0; @@ -900,7 +948,7 @@ atl_status_t atl_ofi_prov_init(atl_ctx_t* ctx, if (ctx->coord.global_idx == 0) { LOG_INFO("provider: ", info->fabric_attr->prov_name); - LOG_INFO(" nic: ", atl_ofi_get_nic_name(info)); + LOG_INFO(" nic: ", atl_ofi_get_nic_info(info)); LOG_INFO(" mr_mode: ", info->domain_attr->mr_mode); LOG_INFO(" threading: ", info->domain_attr->threading); LOG_INFO(" tx_ctx_cnt: ", info->domain_attr->tx_ctx_cnt); @@ -983,15 +1031,15 @@ atl_status_t atl_ofi_adjust_out_tag(atl_ofi_prov_t* prov, atl_attr_t* attr) { const char* prov_name = prov->info->fabric_attr->prov_name; - CCL_THROW_IF_NOT(attr->out.tag_bits > 0, - "unexpected tag_bits ", - attr->out.tag_bits, - " for prov ", - prov_name); - - CCL_THROW_IF_NOT( - attr->out.max_tag > 0, "unexpected max_tag ", attr->out.max_tag, " for prov ", prov_name); + if (!(attr->out.tag_bits > 0)) { + LOG_ERROR("unexpected tag_bits ", attr->out.tag_bits, " for prov ", prov_name); + return ATL_STATUS_FAILURE; + } + if (!(attr->out.max_tag > 0)) { + LOG_ERROR("unexpected max_tag ", attr->out.max_tag, " for prov ", prov_name); + return ATL_STATUS_FAILURE; + } LOG_INFO(prov_name, " tag_bits: ", attr->out.tag_bits, @@ -1003,12 +1051,21 @@ atl_status_t atl_ofi_adjust_out_tag(atl_ofi_prov_t* prov, atl_attr_t* attr) { return ATL_STATUS_SUCCESS; } +static bool atl_ofi_is_nic_down(struct fi_info* prov) { + if (prov->nic && prov->nic->link_attr->state == FI_LINK_DOWN) { + return true; + } + + return false; +} + /* determine if NIC has already been included in others */ 
int atl_ofi_nic_already_used(const struct fi_info* prov, - struct fi_info** others, - size_t nic_count) { - for (size_t i = 0; i < nic_count; i++) { - if (prov->nic && others[i]->nic && prov->nic->bus_attr->bus_type == FI_BUS_PCI && + const std::vector& others, + bool check_pci = false) { + for (size_t i = 0; i < others.size(); i++) { + if (check_pci && prov->nic && others[i]->nic && + prov->nic->bus_attr->bus_type == FI_BUS_PCI && others[i]->nic->bus_attr->bus_type == FI_BUS_PCI) { struct fi_pci_attr pci = prov->nic->bus_attr->attr.pci; struct fi_pci_attr other_pci = others[i]->nic->bus_attr->attr.pci; @@ -1164,23 +1221,31 @@ int atl_ofi_is_allowed_nic_name(atl_ofi_ctx_t* ofi_ctx, struct fi_info* info) { return (should_include && !should_exclude); } +bool atl_ofi_compare_nics(const struct fi_info* nic1, const struct fi_info* nic2) { + if (nic1->nic && !nic2->nic) { + return true; + } + else if (!nic1->nic && nic2->nic) { + return false; + } + return (atl_ofi_get_short_nic_name(nic1) < atl_ofi_get_short_nic_name(nic2)); +} + atl_status_t atl_ofi_open_nw_provs(atl_ctx_t* ctx, struct fi_info* base_hints, atl_attr_t* attr, - std::unique_ptr& pmi) { + std::shared_ptr pmi) { atl_status_t ret = ATL_STATUS_SUCCESS; struct fi_info* prov_list = nullptr; struct fi_info* prov_iter = nullptr; size_t idx = 0, prov_idx = 0; char* prov_name = nullptr; atl_ofi_prov_t* prov = nullptr; - size_t name_prov_count = 0; - size_t topo_prov_count = 0; - size_t final_prov_count = 0; - struct fi_info* name_prov_list[ATL_OFI_MAX_NW_PROV_COUNT] = { 0 }; - struct fi_info* topo_prov_list[ATL_OFI_MAX_NW_PROV_COUNT] = { 0 }; - struct fi_info* final_prov_list[ATL_OFI_MAX_NW_PROV_COUNT] = { 0 }; + std::vector name_provs; + std::vector topo_provs; + std::vector final_provs; std::set all_nic_names; + int prov_offset = 0; atl_ofi_ctx_t* ofi_ctx = container_of(ctx, atl_ofi_ctx_t, ctx); @@ -1197,18 +1262,23 @@ atl_status_t atl_ofi_open_nw_provs(atl_ctx_t* ctx, prov_iter = prov_list; while (prov_iter) { 
LOG_DEBUG("name filter: check nic ", atl_ofi_get_nic_name(prov_iter)); - if (!atl_ofi_nic_already_used(prov_iter, name_prov_list, name_prov_count)) { + if (atl_ofi_is_nic_down(prov_iter)) { + LOG_DEBUG("nic ", atl_ofi_get_nic_name(prov_iter), " is in down state, skip"); + } + else if (!atl_ofi_nic_already_used(prov_iter, name_provs)) { all_nic_names.insert(atl_ofi_get_short_nic_name(prov_iter)); if (atl_ofi_is_allowed_nic_name(ofi_ctx, prov_iter)) { LOG_DEBUG("name filter: found suitable nic ", atl_ofi_get_nic_name(prov_iter)); - name_prov_list[name_prov_count] = fi_dupinfo(prov_iter); - name_prov_count++; + name_provs.push_back(fi_dupinfo(prov_iter)); } } prov_iter = prov_iter->next; } - if (!name_prov_count) { + /* sort by names */ + std::sort(name_provs.begin(), name_provs.end(), atl_ofi_compare_nics); + + if (name_provs.empty()) { LOG_ERROR("name filter: can not find network providers", ", include names: ", vec_to_string(ofi_ctx->mnic_include_names), @@ -1221,13 +1291,12 @@ atl_status_t atl_ofi_open_nw_provs(atl_ctx_t* ctx, /* 3. 
filter out by topo */ if (ofi_ctx->mnic_type == ATL_MNIC_NONE) { - topo_prov_list[topo_prov_count] = fi_dupinfo(name_prov_list[0]); - topo_prov_count++; + topo_provs.push_back(fi_dupinfo(name_provs[0])); } else { struct fid_nic* nic = nullptr; - for (idx = 0; idx < name_prov_count; idx++) { - prov_iter = name_prov_list[idx]; + for (idx = 0; idx < name_provs.size(); idx++) { + prov_iter = name_provs[idx]; LOG_DEBUG("topo filter: check nic ", atl_ofi_get_nic_name(prov_iter)); nic = prov_iter->nic; @@ -1236,15 +1305,14 @@ atl_status_t atl_ofi_open_nw_provs(atl_ctx_t* ctx, ", has nic_attr ", (nic != nullptr)); - if (!atl_ofi_nic_already_used(prov_iter, topo_prov_list, topo_prov_count)) { + if (!atl_ofi_nic_already_used(prov_iter, topo_provs)) { int is_local = atl_ofi_is_nic_local(prov_iter); LOG_DEBUG( "topo filter: nic ", atl_ofi_get_nic_name(prov_iter), ", is_local ", is_local); if (ofi_ctx->mnic_type == ATL_MNIC_GLOBAL || (ofi_ctx->mnic_type == ATL_MNIC_LOCAL && is_local)) { LOG_DEBUG("topo filter: found suitable nic ", atl_ofi_get_nic_name(prov_iter)); - topo_prov_list[topo_prov_count] = fi_dupinfo(prov_iter); - topo_prov_count++; + topo_provs.push_back(fi_dupinfo(prov_iter)); } } else { @@ -1253,58 +1321,64 @@ atl_status_t atl_ofi_open_nw_provs(atl_ctx_t* ctx, } } - if (!topo_prov_count) { + if (topo_provs.empty()) { LOG_ERROR("topo filter: can not find network providers, mnic_type ", ofi_ctx->mnic_type); goto err; } - /* 4. filter out by count */ - for (idx = 0; idx < topo_prov_count; idx++) { - prov_iter = topo_prov_list[idx]; + /* 4. reorder according to desired offset */ + if (ofi_ctx->mnic_offset == ATL_MNIC_OFFSET_LOCAL_PROC_IDX) { + prov_offset = ctx->coord.local_idx % topo_provs.size(); + } + LOG_DEBUG("rotate: prov_offset ", prov_offset, ", vec_size ", topo_provs.size()); + std::rotate(topo_provs.begin(), topo_provs.begin() + prov_offset, topo_provs.end()); + + /* 5. 
filter out by count */ + for (idx = 0; idx < topo_provs.size(); idx++) { + prov_iter = topo_provs[idx]; LOG_DEBUG("count filter: check nic ", atl_ofi_get_nic_name(prov_iter)); - if (final_prov_count < ofi_ctx->mnic_count) { + if (final_provs.size() < ofi_ctx->mnic_count) { LOG_DEBUG("count filter: found suitable nic ", atl_ofi_get_nic_name(prov_iter), ", nic idx ", - final_prov_count); - final_prov_list[final_prov_count] = fi_dupinfo(prov_iter); - final_prov_count++; + final_provs.size()); + final_provs.push_back(fi_dupinfo(prov_iter)); } else { break; } } - if (!final_prov_count) { + if (final_provs.empty()) { LOG_ERROR("count filter: can not find network providers, mnic_count ", ofi_ctx->mnic_count); goto err; } - /* 5. create network providers */ - LOG_INFO("found ", final_prov_count, " nic(s) according to all filters"); - ofi_ctx->nw_prov_count = final_prov_count; + /* 6. create network providers */ + LOG_INFO("found ", final_provs.size(), " nic(s) according to all filters"); + ofi_ctx->nw_prov_count = final_provs.size(); for (idx = 0; idx < ofi_ctx->nw_prov_count; idx++) { prov_idx = ofi_ctx->nw_prov_first_idx + idx; prov = &ofi_ctx->provs[prov_idx]; prov->idx = prov_idx; prov->is_shm = 0; - ATL_CALL(atl_ofi_prov_init(ctx, final_prov_list[idx], prov, attr, pmi), goto err); + ATL_CALL(atl_ofi_prov_init(ctx, final_provs[idx], prov, attr, pmi), goto err); } exit: - for (idx = 0; idx < final_prov_count; idx++) { - if (final_prov_list[idx]) - fi_freeinfo(final_prov_list[idx]); + for (idx = 0; idx < final_provs.size(); idx++) { + if (final_provs[idx]) + fi_freeinfo(final_provs[idx]); } - for (idx = 0; idx < topo_prov_count; idx++) { - if (topo_prov_list[idx]) - fi_freeinfo(topo_prov_list[idx]); + for (idx = 0; idx < topo_provs.size(); idx++) { + if (topo_provs[idx]) + fi_freeinfo(topo_provs[idx]); } - for (idx = 0; idx < name_prov_count; idx++) { - if (name_prov_list[idx]) - fi_freeinfo(name_prov_list[idx]); + for (idx = 0; idx < name_provs.size(); idx++) { + if 
(name_provs[idx]) + fi_freeinfo(name_provs[idx]); } fi_freeinfo(prov_list); @@ -1318,3 +1392,11 @@ atl_status_t atl_ofi_open_nw_provs(atl_ctx_t* ctx, ret = ATL_STATUS_FAILURE; goto exit; } + +void atl_ofi_init_req(atl_req_t* req, atl_ofi_prov_ep_t* prov_ep, struct fid_ep* fi_ep) { + atl_ofi_req_t* ofi_req = ((atl_ofi_req_t*)req->internal); + ofi_req->prov_ep = prov_ep; + ofi_req->fi_ep = fi_ep; + ofi_req->comp_state = ATL_OFI_COMP_POSTED; + req->is_completed = 0; +} diff --git a/src/atl/ofi/atl_ofi_helper.hpp b/src/atl/ofi/atl_ofi_helper.hpp index 59e776e24..c8cc040af 100644 --- a/src/atl/ofi/atl_ofi_helper.hpp +++ b/src/atl/ofi/atl_ofi_helper.hpp @@ -34,11 +34,11 @@ #include #include -#include "atl.h" +#include "atl/util/pm/pm_rt.h" #include "common/global/global.hpp" #include "hwloc/hwloc_wrapper.hpp" #ifdef CCL_ENABLE_OFI_HMEM -#include "sched/entry/gpu/ze_primitives.hpp" +#include "sched/entry/ze/ze_primitives.hpp" #endif // CCL_ENABLE_OFI_HMEM #define ATL_OFI_BASE_PM_KEY "atl-ofi" @@ -113,7 +113,7 @@ CCL_THROW("OFI function error"); \ break; \ } \ - (void)atl_ep_poll(ep); \ + (void)poll(ep); \ retry_count++; \ } while (((ret_val) == -FI_EAGAIN) && (retry_count < max_retry_count)); \ } while (0) @@ -213,6 +213,7 @@ typedef struct { std::vector mnic_include_names; std::vector mnic_exclude_names; size_t mnic_count; + atl_mnic_offset_t mnic_offset; int enable_hmem; } atl_ofi_ctx_t; @@ -275,14 +276,14 @@ std::string atl_ofi_get_short_nic_name(const struct fi_info* prov); std::string atl_ofi_get_nic_name(const struct fi_info* prov); atl_ofi_prov_t* atl_ofi_get_prov(atl_ep_t* ep, int peer_proc_idx, size_t msg_size); fi_addr_t atl_ofi_get_addr(atl_ctx_t* ctx, atl_ofi_prov_t* prov, int proc_idx, size_t ep_idx); -atl_status_t atl_ofi_get_local_proc_coord(atl_ofi_ctx_t* ofi_ctx, std::unique_ptr& pmi); +atl_status_t atl_ofi_get_local_proc_coord(atl_ofi_ctx_t* ofi_ctx, std::shared_ptr pmi); atl_status_t atl_ofi_prov_update_addr_table(atl_ofi_ctx_t* ofi_ctx, size_t 
prov_idx, - std::unique_ptr& pmi); + std::shared_ptr pmi); atl_status_t atl_ofi_prov_ep_get_name(atl_ofi_prov_t* prov, size_t ep_idx); atl_status_t atl_ofi_prov_eps_connect(atl_ofi_ctx_t* ofi_ctx, size_t prov_idx, - std::unique_ptr& pmi); + std::shared_ptr pmi); void atl_ofi_prov_ep_destroy(atl_ofi_prov_t* prov, atl_ofi_prov_ep_t* ep); void atl_ofi_prov_destroy(atl_ctx_t* ctx, atl_ofi_prov_t* prov); int atl_ofi_wait_cancel_cq(struct fid_cq* cq); @@ -300,13 +301,12 @@ atl_status_t atl_ofi_prov_init(atl_ctx_t* ctx, struct fi_info* info, atl_ofi_prov_t* prov, atl_attr_t* attr, - std::unique_ptr& pmi); + std::shared_ptr pmi); atl_status_t atl_ofi_adjust_out_tag(atl_ofi_prov_t* prov, atl_attr_t* attr); -int atl_ofi_nic_already_used(const struct fi_info* prov, struct fi_info** others, size_t nic_count); -int atl_ofi_is_nic_local(struct fi_info* info); atl_status_t atl_ofi_parse_mnic_name(atl_ctx_t* ctx, std::string str_to_parse); int atl_ofi_is_allowed_nic_name(atl_ofi_ctx_t* ofi_ctx, struct fi_info* info); atl_status_t atl_ofi_open_nw_provs(atl_ctx_t* ctx, struct fi_info* base_hints, atl_attr_t* attr, - std::unique_ptr& pmi); + std::shared_ptr pmi); +void atl_ofi_init_req(atl_req_t* req, atl_ofi_prov_ep_t* prov_ep, struct fid_ep* fi_ep); diff --git a/src/atl/util/pm/pm_rt.h b/src/atl/util/pm/pm_rt.h index ac328fddc..e896b4d52 100644 --- a/src/atl/util/pm/pm_rt.h +++ b/src/atl/util/pm/pm_rt.h @@ -148,7 +148,7 @@ static inline atl_status_t pmrt_kvs_get(pm_rt_desc_t *pmrt_desc, #ifdef __cplusplus class ipmi { public: - virtual ~ipmi() = default; + virtual ~ipmi() noexcept(false){}; virtual int is_pm_resize_enabled() = 0; @@ -160,9 +160,9 @@ class ipmi { virtual atl_status_t pmrt_wait_notification() = 0; - virtual void pmrt_finalize() = 0; + virtual atl_status_t pmrt_finalize() = 0; - virtual void pmrt_barrier() = 0; + virtual atl_status_t pmrt_barrier() = 0; virtual atl_status_t pmrt_kvs_put(char *kvs_key, int proc_idx, @@ -180,13 +180,15 @@ class ipmi { virtual size_t 
get_local_thread_idx() = 0; - virtual size_t get_local_kvs_id() = 0; + virtual atl_status_t get_local_kvs_id(size_t &res) = 0; - virtual void set_local_kvs_id(size_t local_kvs_id) = 0; + virtual atl_status_t set_local_kvs_id(size_t local_kvs_id) = 0; virtual size_t get_threads_per_process() = 0; virtual size_t get_ranks_per_process() = 0; + + virtual atl_status_t pmrt_init() = 0; }; #endif #endif // PM_RT_H diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable.cpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable.cpp index ec1f9ba30..b642f3b5d 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable.cpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable.cpp @@ -26,52 +26,54 @@ int pmi_resizable::is_pm_resize_enabled() { return true; } -atl_status_t pmi_resizable::pmrt_init(const char *main_addr) { - int ret; +atl_status_t pmi_resizable::pmrt_init() { + kvs_status_t ret; size_t max_kvsnamelen; - ret = PMIR_Init(main_addr); - if (ret != PMIR_SUCCESS) - return ATL_STATUS_FAILURE; + KVS_2_ATL_CHECK_STATUS(PMIR_Init(main_addr.data()), "failed to init"); - ret = PMIR_Update(); - if (ret != PMIR_SUCCESS) - return ATL_STATUS_FAILURE; + KVS_2_ATL_CHECK_STATUS(PMIR_Update(), "failed to update"); ret = PMIR_Get_size(&size); - if (ret != PMIR_SUCCESS) + if (ret != KVS_STATUS_SUCCESS) goto err_resizable; ret = PMIR_Get_rank(&rank); - if (ret != PMIR_SUCCESS) + if (ret != KVS_STATUS_SUCCESS) goto err_resizable; ret = PMIR_KVS_Get_name_length_max(&max_kvsnamelen); - if (ret != PMIR_SUCCESS) + if (ret != KVS_STATUS_SUCCESS) goto err_resizable; kvsname = (char *)calloc(1, max_kvsnamelen); - if (!kvsname) + if (!kvsname) { + LOG_ERROR("memory allocaion failed"); goto err_resizable; + } ret = PMIR_KVS_Get_my_name(kvsname, max_kvsnamelen); - if (ret != PMIR_SUCCESS) + if (ret != KVS_STATUS_SUCCESS) goto err_alloc_key; ret = PMIR_KVS_Get_key_length_max(&max_keylen); - if (ret != PMIR_SUCCESS) + if (ret != KVS_STATUS_SUCCESS) goto err_alloc_key; key_storage = (char *)calloc(1, 
max_keylen); - if (!key_storage) + if (!key_storage) { + LOG_ERROR("memory allocaion failed"); goto err_alloc_key; + } ret = PMIR_KVS_Get_value_length_max(&max_vallen); - if (ret != PMIR_SUCCESS) + if (ret != KVS_STATUS_SUCCESS) goto err_alloc_val; val_storage = (char *)calloc(1, max_vallen); - if (!val_storage) + if (!val_storage) { + LOG_ERROR("memory allocaion failed"); goto err_alloc_val; + } initialized = true; @@ -82,76 +84,80 @@ atl_status_t pmi_resizable::pmrt_init(const char *main_addr) { free(kvsname); err_resizable: PMIR_Finalize(); + LOG_ERROR("failed"); return ATL_STATUS_FAILURE; } atl_status_t pmi_resizable::pmrt_main_addr_reserve(char *main_addr) { - int ret = PMIR_Main_Addr_Reserve(main_addr); - - if (ret) + if (PMIR_Main_Addr_Reserve(main_addr) != KVS_STATUS_SUCCESS) return ATL_STATUS_FAILURE; return ATL_STATUS_SUCCESS; } atl_status_t pmi_resizable::pmrt_set_resize_function(atl_resize_fn_t resize_fn) { - int ret = PMIR_set_resize_function((pmir_resize_fn_t)resize_fn); - - if (ret) + if (PMIR_set_resize_function((pmir_resize_fn_t)resize_fn) != KVS_STATUS_SUCCESS) return ATL_STATUS_FAILURE; return ATL_STATUS_SUCCESS; } atl_status_t pmi_resizable::pmrt_update() { - int ret; + kvs_status_t ret; ret = PMIR_Update(); - if (ret != PMIR_SUCCESS) + if (ret != KVS_STATUS_SUCCESS) goto err_resizable; ret = PMIR_Get_size(&size); - if (ret != PMIR_SUCCESS) + if (ret != KVS_STATUS_SUCCESS) goto err_resizable; ret = PMIR_Get_rank(&rank); - if (ret != PMIR_SUCCESS) + if (ret != KVS_STATUS_SUCCESS) goto err_resizable; return ATL_STATUS_SUCCESS; err_resizable: PMIR_Finalize(); + LOG_ERROR("failed"); return ATL_STATUS_FAILURE; } atl_status_t pmi_resizable::pmrt_wait_notification() { - int ret; + kvs_status_t ret; ret = PMIR_Wait_notification(); - if (ret != PMIR_SUCCESS) + if (ret != KVS_STATUS_SUCCESS) return ATL_STATUS_FAILURE; return ATL_STATUS_SUCCESS; } -void pmi_resizable::pmrt_finalize() { +atl_status_t pmi_resizable::pmrt_finalize() { is_finalized = true; if 
(!initialized) - return; + return ATL_STATUS_SUCCESS; free(kvsname); free(key_storage); free(val_storage); - PMIR_Finalize(); + if (PMIR_Finalize() != KVS_STATUS_SUCCESS) { + return ATL_STATUS_FAILURE; + } + return ATL_STATUS_SUCCESS; } -void pmi_resizable::pmrt_barrier() { +atl_status_t pmi_resizable::pmrt_barrier() { if (!initialized) - return; + return ATL_STATUS_SUCCESS; - PMIR_Barrier(); + if (PMIR_Barrier() != KVS_STATUS_SUCCESS) { + return ATL_STATUS_FAILURE; + } + return ATL_STATUS_SUCCESS; } atl_status_t pmi_resizable::pmrt_kvs_put(char *kvs_key, @@ -160,28 +166,31 @@ atl_status_t pmi_resizable::pmrt_kvs_put(char *kvs_key, size_t kvs_val_len) { int ret; - if (!initialized) + if (!initialized) { + LOG_ERROR("not initialized yet") return ATL_STATUS_FAILURE; + } - if (kvs_val_len > max_vallen) + if (kvs_val_len > max_vallen) { + LOG_ERROR("asked len > max len"); return ATL_STATUS_FAILURE; + } ret = snprintf(key_storage, max_keylen - 1, RESIZABLE_PMI_RT_KEY_FORMAT, kvs_key, proc_idx); - if (ret < 0) + if (ret < 0) { + LOG_ERROR("snprintf failed"); return ATL_STATUS_FAILURE; + } ret = encode(kvs_val, kvs_val_len, val_storage, max_vallen); - if (ret) + if (ret) { + LOG_ERROR("encode failed"); return ATL_STATUS_FAILURE; + } - ret = PMIR_KVS_Put(kvsname, key_storage, val_storage); - if (ret != PMIR_SUCCESS) - return ATL_STATUS_FAILURE; - - ret = PMIR_KVS_Commit(kvsname); - if (ret != PMIR_SUCCESS) - return ATL_STATUS_FAILURE; + KVS_2_ATL_CHECK_STATUS(PMIR_KVS_Put(kvsname, key_storage, val_storage), "put failed"); + KVS_2_ATL_CHECK_STATUS(PMIR_KVS_Commit(kvsname), "commit failed"); return ATL_STATUS_SUCCESS; } @@ -191,20 +200,25 @@ atl_status_t pmi_resizable::pmrt_kvs_get(char *kvs_key, size_t kvs_val_len) { int ret; - if (!initialized) + if (!initialized) { + LOG_ERROR("not initialized yet") return ATL_STATUS_FAILURE; + } ret = snprintf(key_storage, max_keylen - 1, RESIZABLE_PMI_RT_KEY_FORMAT, kvs_key, proc_idx); - if (ret < 0) + if (ret < 0) { + 
LOG_ERROR("snprintf failed"); return ATL_STATUS_FAILURE; + } - ret = PMIR_KVS_Get(kvsname, key_storage, val_storage, max_vallen); - if (ret != PMIR_SUCCESS) - return ATL_STATUS_FAILURE; + KVS_2_ATL_CHECK_STATUS(PMIR_KVS_Get(kvsname, key_storage, val_storage, max_vallen), + "get failed"); ret = decode(val_storage, kvs_val, kvs_val_len); - if (ret) + if (ret) { + LOG_ERROR("decode failed"); return ATL_STATUS_FAILURE; + } return ATL_STATUS_SUCCESS; } diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable.h b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable.h index 8ec4023d0..d41302548 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable.h +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable.h @@ -48,10 +48,9 @@ class helper; class pmi_resizable final : public ipmi { public: pmi_resizable() = delete; - explicit pmi_resizable(std::shared_ptr k, const char* main_addr = nullptr) { + explicit pmi_resizable(std::shared_ptr k, const char* main_addr = "") + : main_addr(main_addr) { h = std::shared_ptr(new helper(k)); - //TODO: move it in one func - pmrt_init(main_addr); } ~pmi_resizable() override; @@ -66,7 +65,7 @@ class pmi_resizable final : public ipmi { atl_status_t pmrt_wait_notification() override; - void pmrt_barrier() override; + atl_status_t pmrt_barrier() override; atl_status_t pmrt_kvs_put(char* kvs_key, int proc_idx, @@ -78,7 +77,7 @@ class pmi_resizable final : public ipmi { void* kvs_val, size_t kvs_val_len) override; - void Hard_finilize(int sig); + kvs_status_t hard_finalize(int sig); int get_rank() override; @@ -86,9 +85,9 @@ class pmi_resizable final : public ipmi { size_t get_local_thread_idx() override; - size_t get_local_kvs_id() override; + atl_status_t get_local_kvs_id(size_t& res) override; - void set_local_kvs_id(size_t local_kvs_id) override; + atl_status_t set_local_kvs_id(size_t local_kvs_id) override; size_t get_threads_per_process() override { return 1; @@ -97,49 +96,51 @@ class pmi_resizable final : public ipmi { size_t 
get_ranks_per_process() override { return 1; } - void pmrt_finalize() override; + atl_status_t pmrt_finalize() override; + + atl_status_t pmrt_init() override; private: bool is_finalized{ false }; - atl_status_t pmrt_init(const char* main_addr = nullptr); /*Was in API ->*/ - int PMIR_Main_Addr_Reserve(char* main_addr); + kvs_status_t PMIR_Main_Addr_Reserve(char* main_addr); - int PMIR_Init(const char* main_addr); + kvs_status_t PMIR_Init(const char* main_addr); - int PMIR_Finalize(void); + kvs_status_t PMIR_Finalize(void); - int PMIR_Get_size(int* size); + kvs_status_t PMIR_Get_size(int* size); - int PMIR_Get_rank(int* rank); + kvs_status_t PMIR_Get_rank(int* rank); - int PMIR_KVS_Get_my_name(char* kvs_name, size_t length); + kvs_status_t PMIR_KVS_Get_my_name(char* kvs_name, size_t length); - int PMIR_KVS_Get_name_length_max(size_t* length); + kvs_status_t PMIR_KVS_Get_name_length_max(size_t* length); - int PMIR_Barrier(void); + kvs_status_t PMIR_Barrier(void); - int PMIR_Update(void); + kvs_status_t PMIR_Update(void); - int PMIR_KVS_Get_key_length_max(size_t* length); + kvs_status_t PMIR_KVS_Get_key_length_max(size_t* length); - int PMIR_KVS_Get_value_length_max(size_t* length); + kvs_status_t PMIR_KVS_Get_value_length_max(size_t* length); - int PMIR_KVS_Put(const char* kvs_name, const char* key, const char* value); + kvs_status_t PMIR_KVS_Put(const char* kvs_name, const char* key, const char* value); - int PMIR_KVS_Commit(const char* kvs_name); + kvs_status_t PMIR_KVS_Commit(const char* kvs_name); - int PMIR_KVS_Get(const char* kvs_name, const char* key, char* value, size_t length); + kvs_status_t PMIR_KVS_Get(const char* kvs_name, const char* key, char* value, size_t length); - int PMIR_set_resize_function(pmir_resize_fn_t resize_fn); + kvs_status_t PMIR_set_resize_function(pmir_resize_fn_t resize_fn); - int PMIR_Wait_notification(void); + kvs_status_t PMIR_Wait_notification(void); /* <- Was in API*/ kvs_resize_action_t default_checker(int comm_size); 
kvs_resize_action_t call_resize_fn(int comm_size); int rank; int size; + std::string main_addr; pmir_resize_fn_t resize_function = nullptr; std::shared_ptr h; diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/def.h b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/def.h index 98b06efc6..e6fa97330 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/def.h +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/def.h @@ -24,13 +24,22 @@ #include "common/log/log.hpp" +typedef enum { KVS_STATUS_SUCCESS, KVS_STATUS_FAILURE, KVS_STATUS_UNSUPPORTED } kvs_status_t; + +#define KVS_CHECK_STATUS(expr, str) \ + do { \ + if (expr != KVS_STATUS_SUCCESS) { \ + LOG_ERROR(str); \ + return KVS_STATUS_FAILURE; \ + } \ + } while (0) + //TODO: change exit to something more useful #define SET_STR(dst, size, ...) \ do { \ if (snprintf(dst, size, __VA_ARGS__) > size) { \ - printf("line too long (must be shorter %d)\n", size); \ - printf(__VA_ARGS__); \ - exit(1); \ + LOG_ERROR("line too long, must be shorter ", size); \ + return KVS_STATUS_FAILURE; \ } \ } while (0) @@ -38,8 +47,8 @@ do { \ char* res = expr; \ if (!res || res != str) { \ - printf("fgets error\n"); \ - exit(EXIT_FAILURE); \ + LOG_ERROR("fgets error: ", strerror(errno)); \ + return KVS_STATUS_FAILURE; \ } \ } while (0) @@ -61,15 +70,16 @@ buf, \ size, \ shift); \ - perror("read/write error"); \ - exit(EXIT_FAILURE); \ + LOG_ERROR("read/write error: ", strerror(errno)); \ + return KVS_STATUS_FAILURE; \ } \ } \ else if (res == 0) { \ - printf("" #msg ": " #op ": can not process all data, size %zu, shift %zu\n", \ - size, \ - shift); \ - exit(EXIT_FAILURE); \ + LOG_ERROR("" #msg ": " #op \ + ": can not process all data, size %zu, shift %zu\n", \ + size, \ + shift); \ + return KVS_STATUS_FAILURE; \ } \ else { \ shift += res; \ @@ -94,8 +104,8 @@ buf, \ size, \ shift); \ - perror("read/write error"); \ - exit(EXIT_FAILURE); \ + LOG_ERROR("read/write error: ", strerror(errno)); \ + return KVS_STATUS_FAILURE; \ } \ } 
\ else { \ @@ -159,21 +169,23 @@ void inline kvs_str_copy_known_sizes(char* dst, const char* src, size_t bytes) { dst[bytes - 1] = '\0'; } -long int inline safe_strtol(const char* str, char** endptr, int base) { +template +kvs_status_t inline safe_strtol(const char* str, T& val) { errno = 0; - auto val = strtol(str, endptr, base); + val = strtol(str, nullptr, 10); if (errno != 0) { if (errno == EINVAL) { - CCL_THROW("conversion error occurred from: ", str); + LOG_ERROR("conversion error occurred from: ", str); } else if (errno == ERANGE) { - CCL_THROW("the value provided was out of range: ", str); + LOG_ERROR("the value provided was out of range: ", str); } else { - CCL_THROW("strtol error: ", strerror(errno), ", str: ", str); + LOG_ERROR("strtol error: ", strerror(errno), ", str: ", str); } + return KVS_STATUS_FAILURE; } - return val; + return KVS_STATUS_SUCCESS; } diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.cpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.cpp index 4aa2e298f..e5516cf06 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.cpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.cpp @@ -24,14 +24,16 @@ size_t barrier_num = 0; size_t up_idx; size_t applied = 0; -rank_list_t* killed_ranks = NULL; +std::list killed_ranks; int killed_ranks_count = 0; -rank_list_t* new_ranks = NULL; +std::list new_ranks; int new_ranks_count = 0; -size_t helper::replace_str(char* str, int old_rank, int new_rank) { - throw std::runtime_error("unexpected path"); +kvs_status_t helper::replace_str(char* str, int old_rank, int new_rank) { + // throw std::runtime_error("unexpected path"); + LOG_ERROR("unexpected path"); + return KVS_STATUS_FAILURE; char old_str[INT_STR_SIZE]; char new_str[INT_STR_SIZE]; @@ -43,8 +45,10 @@ size_t helper::replace_str(char* str, int old_rank, int new_rank) { SET_STR(new_str, INT_STR_SIZE, RANK_TEMPLATE, new_rank); point_to_replace = strstr(str, old_str); - if (point_to_replace == NULL) - 
return 1; + if (point_to_replace == NULL) { + LOG_ERROR("not found old rank(%d) in str(%s)", old_rank, str); + return KVS_STATUS_FAILURE; + } old_str_size = strlen(old_str); new_str_size = strlen(new_str); @@ -54,25 +58,31 @@ size_t helper::replace_str(char* str, int old_rank, int new_rank) { memmove(point_to_replace + new_str_size, point_to_replace + old_str_size, rest_len); } memcpy(point_to_replace, new_str, new_str_size); - return 0; + return KVS_STATUS_SUCCESS; } -void helper::update_ranks(int* old_count, rank_list_t** origin_list, const char* kvs_name) { +kvs_status_t helper::update_ranks(int* old_count, + std::list& origin_list, + const char* kvs_name) { char** rank_nums = NULL; - size_t rank_count = get_keys_values_by_name(kvs_name, NULL, &rank_nums); + size_t rank_count; + KVS_CHECK_STATUS(get_keys_values_by_name(kvs_name, NULL, &rank_nums, rank_count), + "failed to get values by name"); size_t i; size_t cur_count = 0; if (rank_count == 0) { // *old_count = 0; - return; + return KVS_STATUS_SUCCESS; } - + int rank_num; for (i = 0; i < rank_count; i++) { - if (rank_list_contains(*origin_list, safe_strtol(rank_nums[i], NULL, 10))) + KVS_CHECK_STATUS(safe_strtol(rank_nums[i], rank_num), "failed to to convert rank_num"); + + if (std::find(origin_list.begin(), origin_list.end(), rank_num) != origin_list.end()) continue; - rank_list_add(origin_list, safe_strtol(rank_nums[i], NULL, 10)); + origin_list.push_back(rank_num); cur_count++; } @@ -82,91 +92,100 @@ void helper::update_ranks(int* old_count, rank_list_t** origin_list, const char* free(rank_nums); *old_count += cur_count; + return KVS_STATUS_SUCCESS; } void helper::keep_first_n_up(int prev_new_ranks_count, int prev_killed_ranks_count) { - rank_list_keep_first_n(&killed_ranks, prev_killed_ranks_count); - rank_list_keep_first_n(&new_ranks, prev_new_ranks_count); + killed_ranks.resize(prev_killed_ranks_count); + new_ranks.resize(prev_new_ranks_count); } -void helper::get_update_ranks(void) { - 
update_ranks(&killed_ranks_count, &killed_ranks, KVS_APPROVED_DEAD_POD); - update_ranks(&new_ranks_count, &new_ranks, KVS_APPROVED_NEW_POD); +kvs_status_t helper::get_update_ranks(void) { + KVS_CHECK_STATUS(update_ranks(&killed_ranks_count, killed_ranks, KVS_APPROVED_DEAD_POD), + "failed to update killed ranks"); + KVS_CHECK_STATUS(update_ranks(&new_ranks_count, new_ranks, KVS_APPROVED_NEW_POD), + "failed to update new ranks"); + return KVS_STATUS_SUCCESS; } -void helper::get_shift(shift_list_t** list) { +void helper::get_shift(std::list& list) { int shift_pods_count = 0; int max_rank_survivor_pod = count_pods; - rank_list_t* cur_new = new_ranks; - rank_list_t* cur_killed = killed_ranks; - - if (killed_ranks != NULL) - rank_list_sort(killed_ranks); - if (new_ranks != NULL) - rank_list_sort(new_ranks); - - while (cur_killed != NULL) { - if (cur_new != NULL) { - shift_list_add(list, cur_killed->rank, cur_killed->rank, CH_T_UPDATE); - cur_new = cur_new->next; + new_ranks.sort(); + killed_ranks.sort(); + auto cur_new = new_ranks.begin(); + auto cur_killed = killed_ranks.begin(); + + while (cur_killed != killed_ranks.end()) { + if (cur_new != new_ranks.end()) { + list.push_back({ *cur_killed, *cur_killed, CH_T_UPDATE }); + cur_new++; } else { - while (rank_list_contains(cur_killed, max_rank_survivor_pod - shift_pods_count - 1) == - 1) + while (std::find(cur_killed, + killed_ranks.end(), + max_rank_survivor_pod - shift_pods_count - 1) != killed_ranks.end()) { max_rank_survivor_pod--; + } - if (cur_killed->rank < max_rank_survivor_pod - shift_pods_count) { - shift_list_add(list, - max_rank_survivor_pod - shift_pods_count - 1, - cur_killed->rank, - CH_T_SHIFT); + if (*cur_killed < max_rank_survivor_pod - shift_pods_count) { + list.push_back( + { max_rank_survivor_pod - shift_pods_count - 1, *cur_killed, CH_T_SHIFT }); shift_pods_count++; } else { - while (cur_killed != NULL) { - shift_list_add(list, cur_killed->rank, cur_killed->rank, CH_T_DEAD); - cur_killed = 
cur_killed->next; + while (cur_killed != killed_ranks.end()) { + list.push_back({ *cur_killed, *cur_killed, CH_T_DEAD }); + cur_killed++; } break; } } - cur_killed = cur_killed->next; + cur_killed++; } - while (cur_new != NULL) { - shift_list_add(list, cur_new->rank, cur_new->rank, CH_T_NEW); - cur_new = cur_new->next; + while (cur_new != new_ranks.end()) { + list.push_back({ *cur_new, *cur_new, CH_T_NEW }); + cur_new++; } } -void helper::up_pods_count(void) { - count_pods = get_count_names(KVS_POD_NUM); +kvs_status_t helper::up_pods_count(void) { + KVS_CHECK_STATUS(get_count_names(KVS_POD_NUM, count_pods), "failed to get count names"); + return KVS_STATUS_SUCCESS; } -void helper::wait_accept(void) { +kvs_status_t helper::wait_accept(void) { char my_rank_str[MAX_KVS_VAL_LENGTH]; my_rank = 0; while (1) { - if (get_value_by_name_key(KVS_ACCEPT, my_hostname, my_rank_str) != 0) { - my_rank = safe_strtol(my_rank_str, NULL, 10); - break; - } + KVS_CHECK_STATUS(get_value_by_name_key(KVS_ACCEPT, my_hostname, my_rank_str), + "failed to get value"); + if (strlen(my_rank_str) == 0) + continue; + KVS_CHECK_STATUS(safe_strtol(my_rank_str, my_rank), "failed to convert my_rank"); + break; } + return KVS_STATUS_SUCCESS; } -void helper::clean_dead_pods_info(rank_list_t* dead_up_idx) { +kvs_status_t helper::clean_dead_pods_info(std::list& dead_up_idx) { size_t i; size_t count_death; char** kvs_keys = NULL; + auto it = dead_up_idx.begin(); - while (dead_up_idx != NULL) { - count_death = get_keys_values_by_name(KVS_APPROVED_DEAD_POD, &kvs_keys, NULL); + while (it != dead_up_idx.end()) { + KVS_CHECK_STATUS( + get_keys_values_by_name(KVS_APPROVED_DEAD_POD, &kvs_keys, NULL, count_death), + "failed to get keys and values"); for (i = 0; i < count_death; i++) { - remove_name_key(KVS_APPROVED_DEAD_POD, kvs_keys[i]); - dead_up_idx = dead_up_idx->next; - if (dead_up_idx == NULL) { + KVS_CHECK_STATUS(remove_name_key(KVS_APPROVED_DEAD_POD, kvs_keys[i]), + "failed to remove name and key"); + 
it++; + if (it == dead_up_idx.end()) { for (; i < count_death; i++) { free(kvs_keys[i]); } @@ -177,9 +196,10 @@ void helper::clean_dead_pods_info(rank_list_t* dead_up_idx) { } if (kvs_keys != NULL) free(kvs_keys); + return KVS_STATUS_SUCCESS; } -void helper::accept_new_ranks(shift_list_t* cur_list) { +kvs_status_t helper::accept_new_ranks(const std::list& list) { char new_rank_str[INT_STR_SIZE]; char old_rank_str[INT_STR_SIZE]; char** kvs_values = NULL; @@ -187,16 +207,19 @@ void helper::accept_new_ranks(shift_list_t* cur_list) { size_t count_values; size_t i = 0; - while (cur_list != NULL) { - if (cur_list->shift.type == CH_T_UPDATE || cur_list->shift.type == CH_T_NEW) { - SET_STR(old_rank_str, INT_STR_SIZE, RANK_TEMPLATE, cur_list->shift.old_rank); - SET_STR(new_rank_str, INT_STR_SIZE, RANK_TEMPLATE, cur_list->shift.new_rank); + for (const auto& cur_list : list) { + if (cur_list.type == CH_T_UPDATE || cur_list.type == CH_T_NEW) { + SET_STR(old_rank_str, INT_STR_SIZE, RANK_TEMPLATE, cur_list.old_rank); + SET_STR(new_rank_str, INT_STR_SIZE, RANK_TEMPLATE, cur_list.new_rank); - count_values = get_keys_values_by_name(KVS_APPROVED_NEW_POD, &kvs_keys, &kvs_values); + KVS_CHECK_STATUS( + get_keys_values_by_name(KVS_APPROVED_NEW_POD, &kvs_keys, &kvs_values, count_values), + "failed to get keys and values"); for (i = 0; i < count_values; i++) { if (!strcmp(kvs_values[i], old_rank_str)) { - set_value(KVS_ACCEPT, kvs_keys[i], new_rank_str); + KVS_CHECK_STATUS(set_value(KVS_ACCEPT, kvs_keys[i], new_rank_str), + "failed to set value"); break; } } @@ -205,22 +228,24 @@ void helper::accept_new_ranks(shift_list_t* cur_list) { free(kvs_values[i]); } } - cur_list = cur_list->next; } - while ((count_values = get_keys_values_by_name(KVS_ACCEPT, NULL, &kvs_values)) != 0) { + do { + KVS_CHECK_STATUS(get_keys_values_by_name(KVS_ACCEPT, NULL, &kvs_values, count_values), + "failed to get keys and values"); for (i = 0; i < count_values; i++) { free(kvs_values[i]); } - } + } while 
(count_values != 0); if (kvs_keys != NULL) free(kvs_keys); if (kvs_values != NULL) free(kvs_values); + return KVS_STATUS_SUCCESS; } -void helper::update_kvs_info(int new_rank) { +kvs_status_t helper::update_kvs_info(int new_rank) { char kvs_name[MAX_KVS_NAME_LENGTH]; char kvs_key[MAX_KVS_KEY_LENGTH]; char kvs_val[MAX_KVS_VAL_LENGTH]; @@ -230,61 +255,66 @@ void helper::update_kvs_info(int new_rank) { for (k = 0; k < kvs_list_size; k++) { cut_head(kvs_name, kvs_key, kvs_val, ST_CLIENT); - remove_name_key(kvs_name, kvs_key); + KVS_CHECK_STATUS(remove_name_key(kvs_name, kvs_key), "failed to remove name and key"); - replace_str(kvs_key, my_rank, new_rank); + KVS_CHECK_STATUS(replace_str(kvs_key, my_rank, new_rank), "failed to replace str"); - set_value(kvs_name, kvs_key, kvs_val); + KVS_CHECK_STATUS(set_value(kvs_name, kvs_key, kvs_val), "failed to set value"); put_key(kvs_name, kvs_key, kvs_val, ST_CLIENT); } + return KVS_STATUS_SUCCESS; } -void helper::move_to_new_rank(int new_rank) { +kvs_status_t helper::move_to_new_rank(int new_rank) { char rank_str[INT_STR_SIZE]; - update_kvs_info(new_rank); + KVS_CHECK_STATUS(update_kvs_info(new_rank), "failed to update kvs info"); my_rank = new_rank; - SET_STR(rank_str, INT_STR_SIZE, RANK_TEMPLATE, my_rank); + SET_STR(rank_str, INT_STR_SIZE, RANK_TEMPLATE, new_rank); // request_set_val(KVS_POD_REQUEST, my_hostname, rank_str); - set_value(KVS_POD_NUM, rank_str, my_hostname); + KVS_CHECK_STATUS(set_value(KVS_POD_NUM, rank_str, my_hostname), "failed to update kvs info"); + return KVS_STATUS_SUCCESS; } -void helper::update_my_info(shift_list_t* list) { +kvs_status_t helper::update_my_info(const std::list& list) { char rank_str[INT_STR_SIZE]; - while (list != NULL) { - if (list->shift.old_rank == my_rank) { + for (const auto& it : list) { + if (it.old_rank == static_cast(my_rank)) { int old_rank = my_rank; - move_to_new_rank(list->shift.new_rank); + KVS_CHECK_STATUS(move_to_new_rank(it.new_rank), "failed to move to new rank"); 
SET_STR(rank_str, INT_STR_SIZE, RANK_TEMPLATE, old_rank); - remove_name_key(KVS_POD_NUM, rank_str); + KVS_CHECK_STATUS(remove_name_key(KVS_POD_NUM, rank_str), + "failed to remove name and key"); break; } - list = list->next; } + return KVS_STATUS_SUCCESS; } -size_t helper::get_barrier_idx(void) { +kvs_status_t helper::get_barrier_idx(size_t& barrier_num_out) { char** kvs_values = NULL; size_t count_kvs_values = 0; size_t tmp_barrier_num; size_t min_barrier_num; size_t i = 0; - count_kvs_values = get_keys_values_by_name(KVS_BARRIER, NULL, &kvs_values); + KVS_CHECK_STATUS(get_keys_values_by_name(KVS_BARRIER, NULL, &kvs_values, count_kvs_values), + "failed to get keys and values"); if (count_kvs_values == 0) - return 0; + return KVS_STATUS_SUCCESS; - min_barrier_num = safe_strtol(kvs_values[0], NULL, 10); + KVS_CHECK_STATUS(safe_strtol(kvs_values[0], min_barrier_num), "failed to convert barrier num"); for (i = 1; i < count_kvs_values; i++) { - tmp_barrier_num = safe_strtol(kvs_values[i], NULL, 10); + KVS_CHECK_STATUS(safe_strtol(kvs_values[i], tmp_barrier_num), + "failed to convert tmp barrier num"); if (min_barrier_num > tmp_barrier_num) min_barrier_num = tmp_barrier_num; } @@ -292,10 +322,12 @@ size_t helper::get_barrier_idx(void) { free(kvs_values[i]); } free(kvs_values); - return min_barrier_num; + + barrier_num_out = min_barrier_num; + return KVS_STATUS_SUCCESS; } -void helper::post_my_info(void) { +kvs_status_t helper::post_my_info(void) { char barrier_num_str[INT_STR_SIZE]; char my_rank_str[INT_STR_SIZE]; @@ -303,106 +335,120 @@ void helper::post_my_info(void) { SET_STR(my_rank_str, INT_STR_SIZE, RANK_TEMPLATE, my_rank); - set_value(KVS_POD_NUM, my_rank_str, my_hostname); + KVS_CHECK_STATUS(set_value(KVS_POD_NUM, my_rank_str, my_hostname), "failed to set rank"); - barrier_num = get_barrier_idx(); + KVS_CHECK_STATUS(get_barrier_idx(barrier_num), "failed to get barrier idx"); SET_STR(barrier_num_str, INT_STR_SIZE, SIZE_T_TEMPLATE, barrier_num); - 
set_value(KVS_BARRIER, my_hostname, barrier_num_str); + KVS_CHECK_STATUS(set_value(KVS_BARRIER, my_hostname, barrier_num_str), + "failed to set barrier idx"); - remove_name_key(KVS_ACCEPT, my_hostname); + KVS_CHECK_STATUS(remove_name_key(KVS_ACCEPT, my_hostname), + "failed to remove accepted hostname"); - remove_name_key(KVS_APPROVED_NEW_POD, my_hostname); + KVS_CHECK_STATUS(remove_name_key(KVS_APPROVED_NEW_POD, my_hostname), + "failed to remove approved hostname"); barrier_num++; if (barrier_num > BARRIER_NUM_MAX) barrier_num = 0; + return KVS_STATUS_SUCCESS; } -size_t helper::update(shift_list_t** list, rank_list_t** dead_up_idx, int root_rank) { +kvs_status_t helper::update(const std::list& list, + std::list& dead_up_idx, + int root_rank) { if (applied == 1) { - if ((*list) != NULL) { - if (my_rank == root_rank) { - if ((*dead_up_idx) != NULL) - clean_dead_pods_info(*dead_up_idx); - - accept_new_ranks(*list); + if (!list.empty()) { + if (static_cast(my_rank) == root_rank) { + if (!dead_up_idx.empty()) { + KVS_CHECK_STATUS(clean_dead_pods_info(dead_up_idx), "failed to clean dead pod"); + } + KVS_CHECK_STATUS(accept_new_ranks(list), "failed to accept new ranks"); } - update_my_info(*list); + KVS_CHECK_STATUS(update_my_info(list), "failed to update info"); } } - else - post_my_info(); - - return 0; + else { + KVS_CHECK_STATUS(post_my_info(), "failed to post info"); + } + return KVS_STATUS_SUCCESS; } -size_t helper::get_val_count(const char* name, const char* val) { - size_t res = 0; +kvs_status_t helper::get_val_count(const char* name, const char* val, size_t& res) { + res = 0; char** kvs_values = NULL; size_t count_values; size_t i; - count_values = get_keys_values_by_name(name, NULL, &kvs_values); - - if (count_values == 0) - return res; + KVS_CHECK_STATUS(get_keys_values_by_name(name, NULL, &kvs_values, count_values), + "failed to get keys and values"); - for (i = 0; i < count_values; i++) { - if (!strcmp(val, kvs_values[i])) { - res++; + if (count_values != 0) 
{ + for (i = 0; i < count_values; i++) { + if (!strcmp(val, kvs_values[i])) { + res++; + } + free(kvs_values[i]); } - free(kvs_values[i]); + free(kvs_values); } - free(kvs_values); - return res; + return KVS_STATUS_SUCCESS; } -size_t helper::get_occupied_ranks_count(char* rank) { +kvs_status_t helper::get_occupied_ranks_count(char* rank, size_t& res) { char occupied_rank_val_str[MAX_KVS_VAL_LENGTH]; size_t is_occupied_rank; size_t count_new_pod = 0; size_t count_seen_new_pod = 0; - is_occupied_rank = - (get_value_by_name_key(KVS_POD_NUM, rank, occupied_rank_val_str) == 0) ? 0 : 1; + KVS_CHECK_STATUS(get_value_by_name_key(KVS_POD_NUM, rank, occupied_rank_val_str), + "failed to get occupied rank"); + + is_occupied_rank = (strlen(occupied_rank_val_str) == 0) ? 0 : 1; - count_new_pod = get_val_count(KVS_NEW_POD, rank); + KVS_CHECK_STATUS(get_val_count(KVS_NEW_POD, rank, count_new_pod), "failed to get mew rank"); - count_seen_new_pod = get_val_count(KVS_APPROVED_NEW_POD, rank); + KVS_CHECK_STATUS(get_val_count(KVS_APPROVED_NEW_POD, rank, count_seen_new_pod), + "failed to get new approved rank"); - return is_occupied_rank + count_new_pod + count_seen_new_pod; + res = is_occupied_rank + count_new_pod + count_seen_new_pod; + return KVS_STATUS_SUCCESS; } -size_t helper::get_count_requested_ranks(char* rank) { - size_t count_pods_with_my_rank = 0; +kvs_status_t helper::get_count_requested_ranks(char* rank, size_t& count_pods_with_my_rank) { + count_pods_with_my_rank = 0; - count_pods_with_my_rank = get_val_count(KVS_POD_REQUEST, rank); + KVS_CHECK_STATUS(get_val_count(KVS_POD_REQUEST, rank, count_pods_with_my_rank), + "failed tp get requested ranks"); - return count_pods_with_my_rank; + return KVS_STATUS_SUCCESS; } -void helper::occupied_rank(char* rank) { +kvs_status_t helper::occupied_rank(char* rank) { char idx_val[MAX_KVS_VAL_LENGTH]; - size_t is_inited; - is_inited = get_value_by_name_key(KVS_UP, KVS_IDX, idx_val); + KVS_CHECK_STATUS(get_value_by_name_key(KVS_UP, 
KVS_IDX, idx_val), "failed to get ID"); - if ((is_inited == 0) && (my_rank == 0)) { - set_value(KVS_UP, KVS_IDX, INITIAL_UPDATE_IDX); + if ((strlen(idx_val) == 0) && (my_rank == 0)) { + KVS_CHECK_STATUS(set_value(KVS_UP, KVS_IDX, INITIAL_UPDATE_IDX), + "failed to set initial ID"); count_pods = 1; - update(NULL, NULL, 0); + std::list clear_list{}; + std::list clear_shift_list{}; + KVS_CHECK_STATUS(update(clear_shift_list, clear_list, 0), "failed to initial update"); } else { - set_value(KVS_NEW_POD, my_hostname, rank); + KVS_CHECK_STATUS(set_value(KVS_NEW_POD, my_hostname, rank), "failed to set rank"); } + return KVS_STATUS_SUCCESS; } -void helper::reg_rank(void) { +kvs_status_t helper::reg_rank(void) { char rank_str[INT_STR_SIZE]; size_t wait_shift = 0; char** kvs_values = NULL; @@ -412,7 +458,8 @@ void helper::reg_rank(void) { size_t i; my_rank = 0; - set_value(KVS_POD_REQUEST, my_hostname, INITIAL_RANK_NUM); + KVS_CHECK_STATUS(set_value(KVS_POD_REQUEST, my_hostname, INITIAL_RANK_NUM), + "failed to set initial rank"); SET_STR(rank_str, INT_STR_SIZE, RANK_TEMPLATE, my_rank); @@ -420,7 +467,9 @@ void helper::reg_rank(void) { wait_shift = 0; my_num_in_pod_request_line = 0; - count_values = get_keys_values_by_name(KVS_POD_REQUEST, &kvs_keys, &kvs_values); + KVS_CHECK_STATUS( + get_keys_values_by_name(KVS_POD_REQUEST, &kvs_keys, &kvs_values, count_values), + "failed to get requested pods"); for (i = 0; i < count_values; i++) { if (!strcmp(kvs_values[i], rank_str)) { @@ -435,13 +484,18 @@ void helper::reg_rank(void) { } if (my_num_in_pod_request_line == 1) { - if (get_occupied_ranks_count(rank_str) != 0) { + size_t rank_count; + KVS_CHECK_STATUS(get_occupied_ranks_count(rank_str, rank_count), + "failed to get occupied ranks count"); + if (rank_count != 0) { wait_shift = 0; } else { wait_shift = 1; - if (get_count_requested_ranks(rank_str) == 1) { - occupied_rank(rank_str); + KVS_CHECK_STATUS(get_count_requested_ranks(rank_str, rank_count), + "failed to get requested 
ranks count"); + if (rank_count == 1) { + KVS_CHECK_STATUS(occupied_rank(rank_str), "failed to get occupied rank"); break; } } @@ -450,33 +504,38 @@ void helper::reg_rank(void) { if (!wait_shift) { my_rank++; SET_STR(rank_str, INT_STR_SIZE, RANK_TEMPLATE, my_rank); - set_value(KVS_POD_REQUEST, my_hostname, rank_str); + KVS_CHECK_STATUS(set_value(KVS_POD_REQUEST, my_hostname, rank_str), + "failed to set rank"); } } - remove_name_key(KVS_POD_REQUEST, my_hostname); + KVS_CHECK_STATUS(remove_name_key(KVS_POD_REQUEST, my_hostname), "failed to remove host info"); if (kvs_keys != NULL) free(kvs_keys); if (kvs_values != NULL) free(kvs_values); + return KVS_STATUS_SUCCESS; } -size_t helper::get_replica_size(void) { - return k->kvs_get_replica_size(); +kvs_status_t helper::get_replica_size(size_t& replica_size) { + return k->kvs_get_replica_size(replica_size); } -void helper::up_kvs(const char* new_kvs_name, const char* old_kvs_name) { +kvs_status_t helper::up_kvs(const char* new_kvs_name, const char* old_kvs_name) { char** kvs_values = NULL; char** kvs_keys = NULL; size_t i = 0; size_t count_values; - count_values = get_keys_values_by_name(old_kvs_name, &kvs_keys, &kvs_values); + KVS_CHECK_STATUS(get_keys_values_by_name(old_kvs_name, &kvs_keys, &kvs_values, count_values), + "failed to get keys and values"); for (i = 0; i < count_values; i++) { - remove_name_key(old_kvs_name, kvs_keys[i]); + KVS_CHECK_STATUS(remove_name_key(old_kvs_name, kvs_keys[i]), + "failed to remove old kvs info"); - set_value(new_kvs_name, kvs_keys[i], kvs_values[i]); + KVS_CHECK_STATUS(set_value(new_kvs_name, kvs_keys[i], kvs_values[i]), + "failed to set new kvs info"); free(kvs_keys[i]); free(kvs_values[i]); @@ -485,48 +544,61 @@ void helper::up_kvs(const char* new_kvs_name, const char* old_kvs_name) { free(kvs_keys); if (kvs_values != NULL) free(kvs_values); + return KVS_STATUS_SUCCESS; } -void helper::up_kvs_new_and_dead(void) { - up_kvs(KVS_APPROVED_NEW_POD, KVS_NEW_POD); - 
up_kvs(KVS_APPROVED_DEAD_POD, KVS_DEAD_POD); +kvs_status_t helper::up_kvs_new_and_dead(void) { + KVS_CHECK_STATUS(up_kvs(KVS_APPROVED_NEW_POD, KVS_NEW_POD), "failed to update new"); + KVS_CHECK_STATUS(up_kvs(KVS_APPROVED_DEAD_POD, KVS_DEAD_POD), "failed to update dead"); + return KVS_STATUS_SUCCESS; } -void helper::get_new_root(int* old_root) { +kvs_status_t helper::get_new_root(int* old_root) { size_t i; char** rank_nums = NULL; - size_t rank_count = get_keys_values_by_name(KVS_DEAD_POD, NULL, &rank_nums); + size_t rank_count; + int rank_num; + KVS_CHECK_STATUS(get_keys_values_by_name(KVS_DEAD_POD, NULL, &rank_nums, rank_count), + "failed to update new"); for (i = 0; i < rank_count; i++) { - if (*old_root == (int)safe_strtol(rank_nums[i], NULL, 10)) + KVS_CHECK_STATUS(safe_strtol(rank_nums[i], rank_num), "failed to update new"); + if (*old_root == rank_num) { (*old_root)++; + } free(rank_nums[i]); } if (rank_nums != NULL) free(rank_nums); + return KVS_STATUS_SUCCESS; } -size_t helper::get_keys_values_by_name(const char* kvs_name, char*** kvs_keys, char*** kvs_values) { - return k->kvs_get_keys_values_by_name(kvs_name, kvs_keys, kvs_values); +kvs_status_t helper::get_keys_values_by_name(const char* kvs_name, + char*** kvs_keys, + char*** kvs_values, + size_t& count) { + return k->kvs_get_keys_values_by_name(kvs_name, kvs_keys, kvs_values, count); } -size_t helper::set_value(const char* kvs_name, const char* kvs_key, const char* kvs_val) { +kvs_status_t helper::set_value(const char* kvs_name, const char* kvs_key, const char* kvs_val) { return k->kvs_set_value(kvs_name, kvs_key, kvs_val); } -size_t helper::remove_name_key(const char* kvs_name, const char* kvs_key) { +kvs_status_t helper::remove_name_key(const char* kvs_name, const char* kvs_key) { return k->kvs_remove_name_key(kvs_name, kvs_key); } -size_t helper::get_value_by_name_key(const char* kvs_name, const char* kvs_key, char* kvs_val) { +kvs_status_t helper::get_value_by_name_key(const char* kvs_name, + const 
char* kvs_key, + char* kvs_val) { return k->kvs_get_value_by_name_key(kvs_name, kvs_key, kvs_val); } size_t helper::init(const char* main_addr) { return k->kvs_init(main_addr); } -size_t helper::main_server_address_reserve(char* main_addr) { +kvs_status_t helper::main_server_address_reserve(char* main_addr) { return k->kvs_main_server_address_reserve(main_addr); } -size_t helper::get_count_names(const char* kvs_name) { - return k->kvs_get_count_names(kvs_name); +kvs_status_t helper::get_count_names(const char* kvs_name, int& count_names) { + return k->kvs_get_count_names(kvs_name, count_names); } -size_t helper::finalize(void) { +kvs_status_t helper::finalize(void) { return k->kvs_finalize(); } diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.hpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.hpp index 8d840a17d..4758cdb66 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.hpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/helper.hpp @@ -27,7 +27,6 @@ #include #include "def.h" -#include "rank_list.hpp" #include "shift_list.hpp" #include "kvs_keeper.hpp" #include "kvs/ikvs_wrapper.h" @@ -37,10 +36,10 @@ extern size_t barrier_num; extern size_t up_idx; extern size_t applied; -extern rank_list_t* killed_ranks; +extern std::list killed_ranks; extern int killed_ranks_count; -extern rank_list_t* new_ranks; +extern std::list new_ranks; extern int new_ranks_count; class helper { @@ -49,73 +48,78 @@ class helper { explicit helper(std::shared_ptr k) : k(std::move(k)){}; ~helper() = default; - void get_update_ranks(void); + kvs_status_t get_update_ranks(void); - size_t get_replica_size(void); + kvs_status_t get_replica_size(size_t& replica_size); - void wait_accept(void); + kvs_status_t wait_accept(void); - size_t update(shift_list_t** list, rank_list_t** dead_up_idx, int root_rank); + kvs_status_t update(const std::list& list, + std::list& dead_up_idx, + int root_rank); - void up_pods_count(void); + kvs_status_t 
up_pods_count(void); - void get_shift(shift_list_t** list); + void get_shift(std::list& list); - void reg_rank(void); + kvs_status_t reg_rank(void); - size_t get_barrier_idx(void); + kvs_status_t get_barrier_idx(size_t& barrier_num_out); - void up_kvs_new_and_dead(void); + kvs_status_t up_kvs_new_and_dead(void); void keep_first_n_up(int prev_new_ranks_count, int prev_killed_ranks_count); - void get_new_root(int* old_root); + kvs_status_t get_new_root(int* old_root); /*Work with KVS, new*/ - size_t set_value(const char* kvs_name, const char* kvs_key, const char* kvs_val); + kvs_status_t set_value(const char* kvs_name, const char* kvs_key, const char* kvs_val); - size_t remove_name_key(const char* kvs_name, const char* kvs_key); + kvs_status_t remove_name_key(const char* kvs_name, const char* kvs_key); - size_t get_value_by_name_key(const char* kvs_name, const char* kvs_key, char* kvs_val); + kvs_status_t get_value_by_name_key(const char* kvs_name, const char* kvs_key, char* kvs_val); size_t init(const char* main_addr); - size_t main_server_address_reserve(char* main_addr); + kvs_status_t main_server_address_reserve(char* main_addr); - size_t get_count_names(const char* kvs_name); + kvs_status_t get_count_names(const char* kvs_name, int& count_names); - size_t finalize(void); + kvs_status_t finalize(void); - size_t get_keys_values_by_name(const char* kvs_name, char*** kvs_keys, char*** kvs_values); + kvs_status_t get_keys_values_by_name(const char* kvs_name, + char*** kvs_keys, + char*** kvs_values, + size_t& count); /*Work with KVS, new*/ private: - size_t replace_str(char* str, int old_rank, int new_rank); + kvs_status_t replace_str(char* str, int old_rank, int new_rank); - void update_ranks(int* old_count, rank_list_t** origin_list, const char* kvs_name); + kvs_status_t update_ranks(int* old_count, std::list& origin_list, const char* kvs_name); - void clean_dead_pods_info(rank_list_t* dead_up_idx); + kvs_status_t clean_dead_pods_info(std::list& dead_up_idx); - 
void accept_new_ranks(shift_list_t* cur_list); + kvs_status_t accept_new_ranks(const std::list& cur_list); - void update_kvs_info(int new_rank); + kvs_status_t update_kvs_info(int new_rank); - void move_to_new_rank(int new_rank); + kvs_status_t move_to_new_rank(int new_rank); - void update_my_info(shift_list_t* list); + kvs_status_t update_my_info(const std::list& list); - void post_my_info(void); + kvs_status_t post_my_info(void); - size_t get_val_count(const char* name, const char* val); + kvs_status_t get_val_count(const char* name, const char* val, size_t& res); - size_t get_occupied_ranks_count(char* rank); + kvs_status_t get_occupied_ranks_count(char* rank, size_t& res); - size_t get_count_requested_ranks(char* rank); + kvs_status_t get_count_requested_ranks(char* rank, size_t& count_pods_with_my_rank); - void occupied_rank(char* rank); + kvs_status_t occupied_rank(char* rank); - void up_kvs(const char* new_kvs_name, const char* old_kvs_name); + kvs_status_t up_kvs(const char* new_kvs_name, const char* old_kvs_name); std::shared_ptr k; }; #endif diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/ikvs_wrapper.h b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/ikvs_wrapper.h index 6f68a78a6..95b3807d7 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/ikvs_wrapper.h +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/ikvs_wrapper.h @@ -16,32 +16,34 @@ #pragma once #include +#include "util/pm/pmi_resizable_rt/pmi_resizable/def.h" class ikvs_wrapper { public: - virtual ~ikvs_wrapper() = default; + virtual ~ikvs_wrapper() noexcept(false){}; - virtual size_t kvs_set_value(const char* kvs_name, - const char* kvs_key, - const char* kvs_val) = 0; + virtual kvs_status_t kvs_set_value(const char* kvs_name, + const char* kvs_key, + const char* kvs_val) = 0; - virtual size_t kvs_remove_name_key(const char* kvs_name, const char* kvs_key) = 0; + virtual kvs_status_t kvs_remove_name_key(const char* kvs_name, const char* kvs_key) = 0; - 
virtual size_t kvs_get_value_by_name_key(const char* kvs_name, - const char* kvs_key, - char* kvs_val) = 0; + virtual kvs_status_t kvs_get_value_by_name_key(const char* kvs_name, + const char* kvs_key, + char* kvs_val) = 0; - virtual size_t kvs_init(const char* main_addr) = 0; + virtual kvs_status_t kvs_init(const char* main_addr) = 0; - virtual size_t kvs_main_server_address_reserve(char* main_addr) = 0; + virtual kvs_status_t kvs_main_server_address_reserve(char* main_addr) = 0; - virtual size_t kvs_get_count_names(const char* kvs_name) = 0; + virtual kvs_status_t kvs_get_count_names(const char* kvs_name, int& count_names) = 0; - virtual size_t kvs_finalize(void) = 0; + virtual kvs_status_t kvs_finalize() = 0; - virtual size_t kvs_get_keys_values_by_name(const char* kvs_name, - char*** kvs_keys, - char*** kvs_values) = 0; + virtual kvs_status_t kvs_get_keys_values_by_name(const char* kvs_name, + char*** kvs_keys, + char*** kvs_values, + size_t& count) = 0; - virtual size_t kvs_get_replica_size(void) = 0; + virtual kvs_status_t kvs_get_replica_size(size_t& replica_size) = 0; }; diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.cpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.cpp index 3a69b0434..2876bebd7 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.cpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.cpp @@ -28,15 +28,15 @@ #include #include -#include "util/pm/pmi_resizable_rt/pmi_resizable/def.h" #include "internal_kvs.h" #include "internal_kvs_server.hpp" #include "common/log/log.hpp" #include "util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.hpp" -size_t internal_kvs::kvs_set_value(const char* kvs_name, const char* kvs_key, const char* kvs_val) { +kvs_status_t internal_kvs::kvs_set_value(const char* kvs_name, + const char* kvs_key, + const char* kvs_val) { kvs_request_t request; - memset(&request, 0, sizeof(kvs_request_t)); request.mode = AM_PUT; 
kvs_str_copy(request.name, kvs_name, MAX_KVS_NAME_LENGTH); kvs_str_copy(request.key, kvs_key, MAX_KVS_KEY_LENGTH); @@ -49,12 +49,13 @@ size_t internal_kvs::kvs_set_value(const char* kvs_name, const char* kvs_key, co client_memory_mutex, "client: put_key_value"); - return 0; + return KVS_STATUS_SUCCESS; } -size_t internal_kvs::kvs_set_size(const char* kvs_name, const char* kvs_key, const char* kvs_val) { +kvs_status_t internal_kvs::kvs_set_size(const char* kvs_name, + const char* kvs_key, + const char* kvs_val) { kvs_request_t request; - memset(&request, 0, sizeof(kvs_request_t)); request.mode = AM_SET_SIZE; kvs_str_copy(request.name, kvs_name, MAX_KVS_NAME_LENGTH); kvs_str_copy(request.key, kvs_key, MAX_KVS_KEY_LENGTH); @@ -67,14 +68,13 @@ size_t internal_kvs::kvs_set_size(const char* kvs_name, const char* kvs_key, con client_memory_mutex, "client: set_size"); - return 0; + return KVS_STATUS_SUCCESS; } -size_t internal_kvs::kvs_barrier_register(const char* kvs_name, - const char* kvs_key, - const char* kvs_val) { +kvs_status_t internal_kvs::kvs_barrier_register(const char* kvs_name, + const char* kvs_key, + const char* kvs_val) { kvs_request_t request; - memset(&request, 0, sizeof(kvs_request_t)); request.mode = AM_BARRIER_REGISTER; kvs_str_copy(request.name, kvs_name, MAX_KVS_NAME_LENGTH); kvs_str_copy(request.key, kvs_key, MAX_KVS_KEY_LENGTH); @@ -87,13 +87,15 @@ size_t internal_kvs::kvs_barrier_register(const char* kvs_name, client_memory_mutex, "client: barrier_register"); - return 0; + return KVS_STATUS_SUCCESS; } -void internal_kvs::kvs_barrier(const char* kvs_name, const char* kvs_key, const char* kvs_val) { +kvs_status_t internal_kvs::kvs_barrier(const char* kvs_name, + const char* kvs_key, + const char* kvs_val) { kvs_request_t request; int is_done; - memset(&request, 0, sizeof(kvs_request_t)); + request.mode = AM_BARRIER; kvs_str_copy(request.name, kvs_name, MAX_KVS_NAME_LENGTH); kvs_str_copy(request.key, kvs_key, MAX_KVS_KEY_LENGTH); @@ -108,11 +110,11 
@@ void internal_kvs::kvs_barrier(const char* kvs_name, const char* kvs_key, const sizeof(is_done), client_memory_mutex, "client: barrier read data"); + return KVS_STATUS_SUCCESS; } -size_t internal_kvs::kvs_remove_name_key(const char* kvs_name, const char* kvs_key) { +kvs_status_t internal_kvs::kvs_remove_name_key(const char* kvs_name, const char* kvs_key) { kvs_request_t request; - memset(&request, 0, sizeof(kvs_request_t)); request.mode = AM_REMOVE; kvs_str_copy(request.name, kvs_name, MAX_KVS_NAME_LENGTH); kvs_str_copy(request.key, kvs_key, MAX_KVS_KEY_LENGTH); @@ -124,12 +126,11 @@ size_t internal_kvs::kvs_remove_name_key(const char* kvs_name, const char* kvs_k client_memory_mutex, "client: remove_key"); - return 0; + return KVS_STATUS_SUCCESS; } -size_t internal_kvs::kvs_register(const char* kvs_name, const char* kvs_key, char* kvs_val) { +kvs_status_t internal_kvs::kvs_register(const char* kvs_name, const char* kvs_key, char* kvs_val) { kvs_request_t request; - memset(&request, 0, sizeof(kvs_request_t)); request.mode = AM_INTERNAL_REGISTER; kvs_str_copy(request.name, kvs_name, MAX_KVS_NAME_LENGTH); kvs_str_copy(request.key, kvs_key, MAX_KVS_KEY_LENGTH); @@ -147,14 +148,13 @@ size_t internal_kvs::kvs_register(const char* kvs_name, const char* kvs_key, cha "client: register read data"); kvs_str_copy(kvs_val, request.val, MAX_KVS_VAL_LENGTH); - return strlen(kvs_val); + return KVS_STATUS_SUCCESS; } -size_t internal_kvs::kvs_get_value_by_name_key(const char* kvs_name, - const char* kvs_key, - char* kvs_val) { +kvs_status_t internal_kvs::kvs_get_value_by_name_key(const char* kvs_name, + const char* kvs_key, + char* kvs_val) { kvs_request_t request; - memset(&request, 0, sizeof(kvs_request_t)); request.mode = AM_GET_VAL; size_t is_exist = 0; kvs_str_copy(request.name, kvs_name, MAX_KVS_NAME_LENGTH); @@ -181,42 +181,37 @@ size_t internal_kvs::kvs_get_value_by_name_key(const char* kvs_name, kvs_str_copy(kvs_val, request.val, MAX_KVS_VAL_LENGTH); } - return 
strlen(kvs_val); + return KVS_STATUS_SUCCESS; } -size_t internal_kvs::kvs_get_count_names(const char* kvs_name) { - size_t count_names = 0; +kvs_status_t internal_kvs::kvs_get_count_names(const char* kvs_name, int& count_names) { + count_names = 0; kvs_request_t request; - memset(&request, 0, sizeof(kvs_request_t)); request.mode = AM_GET_COUNT; kvs_str_copy(request.name, kvs_name, MAX_KVS_NAME_LENGTH); - DO_RW_OP(write, - client_op_sock, - &request, - sizeof(kvs_request_t), - client_memory_mutex, - "client: get_count"); + DO_RW_OP( + write, client_op_sock, &request, sizeof(request), client_memory_mutex, "client: get_count"); DO_RW_OP(read, client_op_sock, &count_names, - sizeof(size_t), + sizeof(count_names), client_memory_mutex, "client: get_count read data"); - return count_names; + return KVS_STATUS_SUCCESS; } -size_t internal_kvs::kvs_get_keys_values_by_name(const char* kvs_name, - char*** kvs_keys, - char*** kvs_values) { - size_t count = 0; +kvs_status_t internal_kvs::kvs_get_keys_values_by_name(const char* kvs_name, + char*** kvs_keys, + char*** kvs_values, + size_t& count) { + count = 0; size_t i; kvs_request_t request; - kvs_request_t* answers; + std::vector answers; - memset(&request, 0, sizeof(kvs_request_t)); request.mode = AM_GET_KEYS_VALUES; kvs_str_copy(request.name, kvs_name, MAX_KVS_NAME_LENGTH); @@ -235,12 +230,12 @@ size_t internal_kvs::kvs_get_keys_values_by_name(const char* kvs_name, "client: get_keys_values read size"); if (count == 0) - return count; + return KVS_STATUS_SUCCESS; - answers = (kvs_request_t*)calloc(count, sizeof(kvs_request_t)); + answers.resize(count); DO_RW_OP(read, client_op_sock, - answers, + answers.data(), sizeof(kvs_request_t) * count, client_memory_mutex, "client: get_keys_values read data"); @@ -251,10 +246,14 @@ size_t internal_kvs::kvs_get_keys_values_by_name(const char* kvs_name, *kvs_keys = (char**)calloc(count, sizeof(char*)); if ((*kvs_keys) == nullptr) { LOG_ERROR("Memory allocation failed"); - exit(1); + return 
KVS_STATUS_FAILURE; } for (i = 0; i < count; i++) { (*kvs_keys)[i] = (char*)calloc(MAX_KVS_KEY_LENGTH, sizeof(char)); + if ((*kvs_keys)[i] == nullptr) { + LOG_ERROR("Memory allocation failed"); + return KVS_STATUS_FAILURE; + } kvs_str_copy((*kvs_keys)[i], answers[i].key, MAX_KVS_KEY_LENGTH); } } @@ -265,27 +264,28 @@ size_t internal_kvs::kvs_get_keys_values_by_name(const char* kvs_name, *kvs_values = (char**)calloc(count, sizeof(char*)); if ((*kvs_values) == nullptr) { LOG_ERROR("Memory allocation failed"); - exit(1); + return KVS_STATUS_FAILURE; } for (i = 0; i < count; i++) { (*kvs_values)[i] = (char*)calloc(MAX_KVS_VAL_LENGTH, sizeof(char)); + if ((*kvs_values)[i] == nullptr) { + LOG_ERROR("Memory allocation failed"); + return KVS_STATUS_FAILURE; + } kvs_str_copy((*kvs_values)[i], answers[i].val, MAX_KVS_VAL_LENGTH); } } - free(answers); - - return count; + return KVS_STATUS_SUCCESS; } -size_t internal_kvs::kvs_get_replica_size(void) { - size_t replica_size = 0; +kvs_status_t internal_kvs::kvs_get_replica_size(size_t& replica_size) { + replica_size = 0; if (ip_getting_mode == IGT_K8S) { - replica_size = request_k8s_get_replica_size(); + return request_k8s_get_replica_size(replica_size); } else { kvs_request_t request; - memset(&request, 0, sizeof(kvs_request_t)); request.mode = AM_GET_REPLICA; DO_RW_OP(write, @@ -302,24 +302,25 @@ size_t internal_kvs::kvs_get_replica_size(void) { client_memory_mutex, "client: get_replica read size"); } - return replica_size; + return KVS_STATUS_SUCCESS; } -size_t internal_kvs::init_main_server_by_k8s() { +kvs_status_t internal_kvs::init_main_server_by_k8s() { char port_str[MAX_KVS_VAL_LENGTH]; - request_k8s_kvs_init(); + KVS_CHECK_STATUS(request_k8s_kvs_init(), "failed to init k8s kvs"); SET_STR(port_str, INT_STR_SIZE, "%d", local_server_address->get_sin_port()); - request_k8s_kvs_get_master(local_host_ip, main_host_ip, port_str); + KVS_CHECK_STATUS(request_k8s_kvs_get_master(local_host_ip, main_host_ip, port_str), + "failed to 
get port"); - main_port = safe_strtol(port_str, nullptr, 10); + KVS_CHECK_STATUS(safe_strtol(port_str, main_port), "failed to convert main_port"); main_server_address->set_sin_port(main_port); - main_server_address->set_sin_addr(main_host_ip); - return 0; + KVS_CHECK_STATUS(main_server_address->set_sin_addr(main_host_ip), "failed to set main_ip"); + return KVS_STATUS_SUCCESS; } -size_t internal_kvs::init_main_server_by_env() { +kvs_status_t internal_kvs::init_main_server_by_env() { char* port = nullptr; const char* tmp_host_ip = (!server_address.empty()) ? server_address.c_str() @@ -327,7 +328,7 @@ size_t internal_kvs::init_main_server_by_env() { if (tmp_host_ip == nullptr) { LOG_ERROR("specify ", CCL_KVS_IP_PORT_ENV); - return 1; + return KVS_STATUS_FAILURE; } memset(main_host_ip, 0, CCL_IP_LEN); @@ -335,25 +336,25 @@ size_t internal_kvs::init_main_server_by_env() { if ((port = strstr(main_host_ip, "_")) == nullptr) { if ((port = strstr(main_host_ip, ":")) == nullptr) { LOG_ERROR("set ", CCL_KVS_IP_PORT_ENV, " in format _\n"); - return 1; + return KVS_STATUS_FAILURE; } } port[0] = '\0'; port++; - main_port = safe_strtol(port, nullptr, 10); + KVS_CHECK_STATUS(safe_strtol(port, main_port), "failed to convert main_port"); main_server_address->set_sin_port(main_port); - main_server_address->set_sin_addr(main_host_ip); - return 0; + KVS_CHECK_STATUS(main_server_address->set_sin_addr(main_host_ip), "failed to set main_ip"); + return KVS_STATUS_SUCCESS; } -size_t internal_kvs::init_main_server_by_string(const char* main_addr) { +kvs_status_t internal_kvs::init_main_server_by_string(const char* main_addr) { char* port = nullptr; - local_server_address->set_sin_addr(local_host_ip); + KVS_CHECK_STATUS(local_server_address->set_sin_addr(local_host_ip), "failed to set main_ip"); if ((server_listen_sock = socket(address_family, SOCK_STREAM, 0)) < 0) { - LOG_ERROR("init_main_server_by_string: server_listen_sock init"); - exit(EXIT_FAILURE); + LOG_ERROR("server_listen_sock 
init"); + return KVS_STATUS_FAILURE; } size_t sin_port = local_server_address->get_sin_port(); @@ -369,28 +370,28 @@ size_t internal_kvs::init_main_server_by_string(const char* main_addr) { if ((port = strstr(main_host_ip, "_")) == nullptr) { if ((port = strstr(main_host_ip, ":")) == nullptr) { - LOG_ERROR( - "init_main_server_by_string: set ", CCL_KVS_IP_PORT_ENV, " in format _"); - return 1; + LOG_ERROR("set ", CCL_KVS_IP_PORT_ENV, " in format _"); + return KVS_STATUS_FAILURE; } } port[0] = '\0'; port++; - main_port = safe_strtol(port, nullptr, 10); + KVS_CHECK_STATUS(safe_strtol(port, main_port), "failed to convert main_port"); main_server_address->set_sin_port(main_port); - main_server_address->set_sin_addr(main_host_ip); - return 0; + KVS_CHECK_STATUS(main_server_address->set_sin_addr(main_host_ip), "failed to set main_ip"); + + return KVS_STATUS_SUCCESS; } -int internal_kvs::fill_local_host_ip() { +kvs_status_t internal_kvs::fill_local_host_ip() { struct ifaddrs *ifaddr, *ifa; int family = AF_UNSPEC; char local_ip[CCL_IP_LEN]; bool is_supported_iface = false; if (getifaddrs(&ifaddr) < 0) { - LOG_ERROR("fill_local_host_ip: can not get host IP"); - return -1; + LOG_ERROR("can not get host IP"); + return KVS_STATUS_FAILURE; } const char iface_name[] = "lo"; @@ -421,10 +422,10 @@ int internal_kvs::fill_local_host_ip() { 0, NI_NUMERICHOST); if (res != 0) { - std::string s("fill_local_host_ip: getnameinfo error > "); + std::string s("getnameinfo error > "); s.append(gai_strerror(res)); LOG_ERROR(s.c_str()); - return -1; + return KVS_STATUS_FAILURE; } local_host_ips.push_back(local_ip); @@ -443,16 +444,18 @@ int internal_kvs::fill_local_host_ip() { } } if (local_host_ips.empty()) { - LOG_ERROR("fill_local_host_ip: can't find interface ", - iface_name_env ? iface_name_env : "", - " to get host IP"); - return -1; + LOG_ERROR("can't find interface ", iface_name_env ? 
iface_name_env : "", " to get host IP"); + return KVS_STATUS_FAILURE; } memset(local_host_ip, 0, CCL_IP_LEN); char* kvs_prefer_ipv6 = std::getenv(CCL_KVS_PREFER_IPV6_ENV.c_str()); - size_t is_kvs_prefer_ipv6 = kvs_prefer_ipv6 ? safe_strtol(kvs_prefer_ipv6, nullptr, 10) : 0; + size_t is_kvs_prefer_ipv6 = 0; + if (kvs_prefer_ipv6) { + KVS_CHECK_STATUS(safe_strtol(kvs_prefer_ipv6, is_kvs_prefer_ipv6), + "failed to set prefer_ip6"); + } if (is_kvs_prefer_ipv6) { if (!local_host_ipv6s.empty()) { @@ -480,25 +483,22 @@ int internal_kvs::fill_local_host_ip() { LOG_DEBUG("use ", address_family == AF_INET ? "ipv4" : "ipv6", ": ", local_host_ip); freeifaddrs(ifaddr); - return 0; + return KVS_STATUS_SUCCESS; } -size_t internal_kvs::kvs_main_server_address_reserve(char* main_address) { +kvs_status_t internal_kvs::kvs_main_server_address_reserve(char* main_address) { if (!server_address.empty()) - return 0; + return KVS_STATUS_SUCCESS; - if (fill_local_host_ip() < 0) { - LOG_ERROR("reserve_main_address: failed to get local host IP"); - exit(EXIT_FAILURE); - } + KVS_CHECK_STATUS(fill_local_host_ip(), "failed to get local host IP"); if ((server_listen_sock = socket(address_family, SOCK_STREAM, 0)) < 0) { - LOG_ERROR("reserve_main_address: server_listen_sock init"); - exit(EXIT_FAILURE); + LOG_ERROR("server_listen_sock init"); + return KVS_STATUS_FAILURE; } - main_server_address->set_sin_addr(local_host_ip); - local_server_address->set_sin_addr(local_host_ip); + KVS_CHECK_STATUS(main_server_address->set_sin_addr(local_host_ip), "failed to set local_ip"); + KVS_CHECK_STATUS(local_server_address->set_sin_addr(local_host_ip), "failed to set local_ip"); size_t sin_port = main_server_address->get_sin_port(); while (bind(server_listen_sock, @@ -516,17 +516,14 @@ size_t internal_kvs::kvs_main_server_address_reserve(char* main_address) { "_%d", main_server_address->get_sin_port()); - return 0; + return KVS_STATUS_SUCCESS; } -size_t internal_kvs::init_main_server_address(const char* 
main_addr) { +kvs_status_t internal_kvs::init_main_server_address(const char* main_addr) { char* ip_getting_type = std::getenv(CCL_KVS_IP_EXCHANGE_ENV.c_str()); if (local_host_ips.empty()) { - if (fill_local_host_ip() < 0) { - LOG_ERROR("init_main_server_address: failed to get local host ip"); - exit(EXIT_FAILURE); - } + KVS_CHECK_STATUS(fill_local_host_ip(), "failed to get local host ip"); } if (ip_getting_type) { @@ -538,28 +535,29 @@ size_t internal_kvs::init_main_server_address(const char* main_addr) { } else { LOG_ERROR("unknown ", CCL_KVS_IP_EXCHANGE_ENV, ": ", ip_getting_type); - return 1; + return KVS_STATUS_FAILURE; } } if (server_address.empty()) { if (main_addr != NULL) { ip_getting_mode = IGT_ENV; - if (server_listen_sock == 0) - init_main_server_by_string(main_addr); - return 0; + if (server_listen_sock == 0) { + KVS_CHECK_STATUS(init_main_server_by_string(main_addr), + "failed to init main server"); + } + return KVS_STATUS_SUCCESS; } } else { ip_getting_mode = IGT_ENV; } - local_server_address->set_sin_addr(local_host_ip); + KVS_CHECK_STATUS(local_server_address->set_sin_addr(local_host_ip), "failed to set local_ip"); if ((server_listen_sock = socket(address_family, SOCK_STREAM, 0)) < 0) { - ; - LOG_ERROR("init_main_server_address: server_listen_sock init"); - exit(EXIT_FAILURE); + LOG_ERROR("server_listen_sock init"); + return KVS_STATUS_FAILURE; } switch (ip_getting_mode) { @@ -576,11 +574,9 @@ size_t internal_kvs::init_main_server_address(const char* main_addr) { return init_main_server_by_k8s(); } case IGT_ENV: { - int res = init_main_server_by_env(); int is_master_node = 0; - if (res) - return res; + KVS_CHECK_STATUS(init_main_server_by_env(), "failed to init_main_server_by_env"); if (strstr(local_host_ip, main_host_ip)) { is_master_node = 1; @@ -592,7 +588,8 @@ size_t internal_kvs::init_main_server_address(const char* main_addr) { is_master_node = 1; memset(local_host_ip, 0, CCL_IP_LEN); kvs_str_copy_known_sizes(local_host_ip, main_host_ip, 
CCL_IP_LEN); - local_server_address->set_sin_addr(local_host_ip); + KVS_CHECK_STATUS(local_server_address->set_sin_addr(local_host_ip), + "get sin add failed"); } } if (is_master_node) { @@ -622,16 +619,16 @@ size_t internal_kvs::init_main_server_address(const char* main_addr) { } } - return res; + return KVS_STATUS_SUCCESS; } default: { LOG_ERROR("unknown ", CCL_KVS_IP_EXCHANGE_ENV); - return 1; + return KVS_STATUS_FAILURE; } } } -size_t internal_kvs::kvs_init(const char* main_addr) { +kvs_status_t internal_kvs::kvs_init(const char* main_addr) { int err; socklen_t len = 0; std::shared_ptr addr; @@ -639,32 +636,32 @@ size_t internal_kvs::kvs_init(const char* main_addr) { time_t start_time; time_t connection_time = 0; - if (init_main_server_address(main_addr)) { - LOG_ERROR("kvs_init: init main server address error"); + if (init_main_server_address(main_addr) != KVS_STATUS_SUCCESS) { + LOG_ERROR("init main server address error"); close(client_op_sock); close(server_control_sock); client_op_sock = 0; server_control_sock = 0; - return 1; + return KVS_STATUS_FAILURE; } if (address_family == AF_INET) { addr = std::shared_ptr(new sockaddr_v4()); - addr->set_sin_addr("127.0.0.1"); + KVS_CHECK_STATUS(addr->set_sin_addr("127.0.0.1"), "failed to set sin_addr(\"127.0.0.1\""); } else { addr = std::shared_ptr(new sockaddr_v6()); - addr->set_sin_addr("::1"); + KVS_CHECK_STATUS(addr->set_sin_addr("::1"), "failed to set sin_addr(\"::1\""); } if ((client_op_sock = socket(address_family, SOCK_STREAM, 0)) < 0) { - LOG_ERROR("kvs_init: client_op_sock init"); - return 1; + LOG_ERROR("client_op_sock init"); + return KVS_STATUS_FAILURE; } if ((server_control_sock = socket(address_family, SOCK_STREAM, 0)) < 0) { - LOG_ERROR("kvs_init: server_control_sock init"); - return 1; + LOG_ERROR("server_control_sock init"); + return KVS_STATUS_FAILURE; } size_t sin_port = addr->get_sin_port(); @@ -674,8 +671,8 @@ size_t internal_kvs::kvs_init(const char* main_addr) { } if 
(listen(server_control_sock, 1) < 0) { - LOG_ERROR("kvs_init: server_control_sock listen"); - exit(EXIT_FAILURE); + LOG_ERROR("server_control_sock listen"); + return KVS_STATUS_FAILURE; } getsockname(server_control_sock, addr->get_sock_addr_ptr(), &len); @@ -684,13 +681,13 @@ size_t internal_kvs::kvs_init(const char* main_addr) { args.sock_listener = server_listen_sock; err = pthread_create(&kvs_thread, nullptr, kvs_server_init, &args); if (err) { - LOG_ERROR("kvs_init: failed to create kvs server thread, pthread_create returns ", err); - return 1; + LOG_ERROR("failed to create kvs server thread, pthread_create returns ", err); + return KVS_STATUS_FAILURE; } if ((client_control_sock = accept(server_control_sock, nullptr, nullptr)) < 0) { - LOG_ERROR("kvs_init: server_control_sock accept"); - exit(EXIT_FAILURE); + LOG_ERROR("server_control_sock accept"); + return KVS_STATUS_FAILURE; } /* Wait connection to master */ @@ -702,11 +699,8 @@ size_t internal_kvs::kvs_init(const char* main_addr) { } while ((err < 0) && (connection_time < CONNECTION_TIMEOUT)); if (connection_time >= CONNECTION_TIMEOUT) { - LOG_ERROR("kvs_init: connection error: timeout limit (", - connection_time, - " > ", - CONNECTION_TIMEOUT); - exit(EXIT_FAILURE); + LOG_ERROR("connection time (", connection_time, ") >= limit (", CONNECTION_TIMEOUT, ")"); + return KVS_STATUS_FAILURE; } if (strstr(main_host_ip, local_host_ip) && local_port == main_port) { @@ -714,13 +708,11 @@ size_t internal_kvs::kvs_init(const char* main_addr) { } is_inited = true; - return 0; + return KVS_STATUS_SUCCESS; } -size_t internal_kvs::kvs_finalize(void) { +kvs_status_t internal_kvs::kvs_finalize(void) { kvs_request_t request; - memset(&request, 0, sizeof(kvs_request_t)); - close(client_op_sock); client_op_sock = 0; if (kvs_thread != 0) { @@ -743,7 +735,8 @@ size_t internal_kvs::kvs_finalize(void) { err = pthread_join(kvs_thread, &exit_code); if (err) { - LOG_ERROR("kvs_finalize: failed to stop kvs server thread, pthread_join 
returns ", err); + LOG_ERROR("failed to stop kvs server thread, pthread_join returns ", err); + return KVS_STATUS_FAILURE; } kvs_thread = 0; @@ -755,19 +748,21 @@ size_t internal_kvs::kvs_finalize(void) { server_control_sock = 0; } - if (ip_getting_mode == IGT_K8S) - request_k8s_kvs_finalize(is_master); + if (ip_getting_mode == IGT_K8S) { + KVS_CHECK_STATUS(request_k8s_kvs_finalize(is_master), "failed to finaluze k8s kvs"); + } is_inited = false; - return 0; + return KVS_STATUS_SUCCESS; } internal_kvs::~internal_kvs() { - if (is_inited) - kvs_finalize(); + if (is_inited) { + CCL_THROW_IF_NOT(kvs_finalize() == KVS_STATUS_SUCCESS, "failed to finalize kvs"); + } } -void sockaddr_v4::set_sin_addr(const char* src) { +kvs_status_t sockaddr_v4::set_sin_addr(const char* src) { int ret = inet_pton(addr.sin_family, src, &(addr.sin_addr)); if (ret <= 0) { if (ret == 0) { @@ -782,17 +777,19 @@ void sockaddr_v4::set_sin_addr(const char* src) { ", error: ", strerror(errno)); } - exit(1); + return KVS_STATUS_FAILURE; } + return KVS_STATUS_SUCCESS; } -void sockaddr_v6::set_sin_addr(const char* src) { +kvs_status_t sockaddr_v6::set_sin_addr(const char* src) { char src_copy[internal_kvs::CCL_IP_LEN] = { 0 }; kvs_str_copy(src_copy, src, internal_kvs::CCL_IP_LEN); char* scope_id_ptr = nullptr; if ((scope_id_ptr = strchr(src_copy, internal_kvs::SCOPE_ID_DELIM))) { - addr.sin6_scope_id = safe_strtol(scope_id_ptr + 1, nullptr, 10); + KVS_CHECK_STATUS(safe_strtol(scope_id_ptr + 1, addr.sin6_scope_id), + "failed to ged sin6_id"); *scope_id_ptr = '\0'; } @@ -812,8 +809,9 @@ void sockaddr_v6::set_sin_addr(const char* src) { ", error: ", strerror(errno)); } - exit(1); + return KVS_STATUS_FAILURE; } - LOG_DEBUG("addr: ", src_copy, ", scope_id: ", addr.sin6_scope_id); + LOG_DEBUG("", src_copy, ", scope_id: ", addr.sin6_scope_id); + return KVS_STATUS_SUCCESS; } diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.h 
b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.h index 7460426d2..01a56fffa 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.h +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs.h @@ -28,7 +28,7 @@ class isockaddr { virtual in_port_t get_sin_port() = 0; virtual void set_sin_port(in_port_t) = 0; virtual const void* get_sin_addr_ptr() = 0; - virtual void set_sin_addr(const char*) = 0; + virtual kvs_status_t set_sin_addr(const char*) = 0; virtual struct sockaddr* get_sock_addr_ptr() = 0; virtual sa_family_t sin_family() = 0; virtual size_t size() = 0; @@ -40,35 +40,40 @@ class isockaddr { class internal_kvs final : public ikvs_wrapper { public: - size_t kvs_set_value(const char* kvs_name, const char* kvs_key, const char* kvs_val) override; + kvs_status_t kvs_set_value(const char* kvs_name, + const char* kvs_key, + const char* kvs_val) override; - size_t kvs_remove_name_key(const char* kvs_name, const char* kvs_key) override; + kvs_status_t kvs_remove_name_key(const char* kvs_name, const char* kvs_key) override; - size_t kvs_get_value_by_name_key(const char* kvs_name, - const char* kvs_key, - char* kvs_val) override; + kvs_status_t kvs_get_value_by_name_key(const char* kvs_name, + const char* kvs_key, + char* kvs_val) override; - size_t kvs_register(const char* kvs_name, const char* kvs_key, char* kvs_val); + kvs_status_t kvs_register(const char* kvs_name, const char* kvs_key, char* kvs_val); - size_t kvs_set_size(const char* kvs_name, const char* kvs_key, const char* kvs_val); + kvs_status_t kvs_set_size(const char* kvs_name, const char* kvs_key, const char* kvs_val); - size_t kvs_barrier_register(const char* kvs_name, const char* kvs_key, const char* kvs_val); + kvs_status_t kvs_barrier_register(const char* kvs_name, + const char* kvs_key, + const char* kvs_val); - void kvs_barrier(const char* kvs_name, const char* kvs_key, const char* kvs_val); + kvs_status_t kvs_barrier(const char* kvs_name, const 
char* kvs_key, const char* kvs_val); - size_t kvs_init(const char* main_addr) override; + kvs_status_t kvs_init(const char* main_addr) override; - size_t kvs_main_server_address_reserve(char* main_addr) override; + kvs_status_t kvs_main_server_address_reserve(char* main_addr) override; - size_t kvs_get_count_names(const char* kvs_name) override; + kvs_status_t kvs_get_count_names(const char* kvs_name, int& count_names) override; - size_t kvs_finalize() override; + kvs_status_t kvs_finalize() override; - size_t kvs_get_keys_values_by_name(const char* kvs_name, - char*** kvs_keys, - char*** kvs_values) override; + kvs_status_t kvs_get_keys_values_by_name(const char* kvs_name, + char*** kvs_keys, + char*** kvs_values, + size_t& count) override; - size_t kvs_get_replica_size() override; + kvs_status_t kvs_get_replica_size(size_t& replica_size) override; ~internal_kvs() override; @@ -80,11 +85,11 @@ class internal_kvs final : public ikvs_wrapper { static const char SCOPE_ID_DELIM = '%'; private: - size_t init_main_server_by_string(const char* main_addr); - size_t init_main_server_by_env(); - size_t init_main_server_by_k8s(); - size_t init_main_server_address(const char* main_addr); - int fill_local_host_ip(); + kvs_status_t init_main_server_by_string(const char* main_addr); + kvs_status_t init_main_server_by_env(); + kvs_status_t init_main_server_by_k8s(); + kvs_status_t init_main_server_address(const char* main_addr); + kvs_status_t fill_local_host_ip(); bool is_inited{ false }; pthread_t kvs_thread = 0; @@ -151,7 +156,7 @@ class sockaddr_v4 : public isockaddr { const void* get_sin_addr_ptr() override { return &(addr.sin_addr); } - void set_sin_addr(const char* src) override; + kvs_status_t set_sin_addr(const char* src) override; sa_family_t sin_family() override { return addr.sin_family; } @@ -180,7 +185,7 @@ class sockaddr_v6 : public isockaddr { const void* get_sin_addr_ptr() override { return &(addr.sin6_addr); } - void set_sin_addr(const char* src) override; + 
kvs_status_t set_sin_addr(const char* src) override; struct sockaddr* get_sock_addr_ptr() override { return (struct sockaddr*)&addr; } diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs_server.cpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs_server.cpp index 7a4cba47f..6ac605e6a 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs_server.cpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs_server.cpp @@ -32,10 +32,10 @@ class server { public: server() = default; - void run(void*); - bool check_finalize(); - void make_client_request(int& socket); - void try_to_connect_new(); + kvs_status_t run(void*); + kvs_status_t check_finalize(bool& to_finalize); + kvs_status_t make_client_request(int& socket); + kvs_status_t try_to_connect_new(); private: struct clients_info { @@ -82,7 +82,7 @@ class server { sa_family_t address_family{ AF_UNSPEC }; }; -void server::try_to_connect_new() { +kvs_status_t server::try_to_connect_new() { if (poll_fds[FDI_LISTENER].revents != 0) { std::shared_ptr addr; @@ -98,8 +98,8 @@ void server::try_to_connect_new() { if ((new_socket = accept(poll_fds[FDI_LISTENER].fd, addr->get_sock_addr_ptr(), (socklen_t*)&peer_addr_size)) < 0) { - perror("server: server_listen_sock accept"); - exit(EXIT_FAILURE); + LOG_ERROR("server_listen_sock accept, %s", strerror(errno)); + return KVS_STATUS_FAILURE; } for (size_t i = FDI_LAST; i < poll_fds.size(); i++) { if (poll_fds[i].fd == free_socket) { @@ -117,16 +117,17 @@ void server::try_to_connect_new() { } } } + return KVS_STATUS_SUCCESS; } -void server::make_client_request(int& socket) { +kvs_status_t server::make_client_request(int& socket) { DO_RW_OP_1( read, socket, &request, sizeof(kvs_request_t), ret, "server: get command from client"); if (ret == 0) { close(socket); socket = free_socket; client_count--; - return; + return KVS_STATUS_SUCCESS; } switch (request.mode) { @@ -184,8 +185,10 @@ void 
server::make_client_request(int& socket) { } case AM_GET_REPLICA: { char* replica_size_str = getenv(CCL_WORLD_SIZE_ENV); - count = (replica_size_str != nullptr) ? safe_strtol(replica_size_str, nullptr, 10) - : client_count; + count = client_count; + if (replica_size_str != nullptr) { + KVS_CHECK_STATUS(safe_strtol(replica_size_str, count), "failed to convert count"); + } DO_RW_OP( write, socket, &count, sizeof(size_t), server_memory_mutex, "server: get_replica"); break; @@ -236,8 +239,8 @@ void server::make_client_request(int& socket) { }); if (client_it == clients.end()) { // TODO: Look deeper to fix this error - printf("Server error: Unregister Barrier request!"); - exit(1); + LOG_ERROR("Server error: Unregister Barrier request!"); + return KVS_STATUS_FAILURE; } auto client_inf = client_it->get(); client_inf->in_barrier = true; @@ -275,9 +278,14 @@ void server::make_client_request(int& socket) { else { local_size[0] = '\0'; local_size++; - barrier.local_size += safe_strtol(local_size, nullptr, 10); + size_t local_size_tmp; + + KVS_CHECK_STATUS(safe_strtol(local_size, local_size_tmp), + "failed to convert local_size"); + barrier.local_size += local_size_tmp; } - barrier.global_size = safe_strtol(glob_size, nullptr, 10); + KVS_CHECK_STATUS(safe_strtol(glob_size, barrier.global_size), + "failed to convert global_size"); barrier.clients.push_back( std::shared_ptr(new clients_info(socket, false))); @@ -285,7 +293,8 @@ void server::make_client_request(int& socket) { } case AM_SET_SIZE: { char* glob_size = request.val; - communicators[request.key].global_size = safe_strtol(glob_size, nullptr, 10); + KVS_CHECK_STATUS(safe_strtol(glob_size, communicators[request.key].global_size), + "failed to convert global_size"); break; } @@ -301,7 +310,9 @@ void server::make_client_request(int& socket) { char* thread_id = strstr(proc_id, "_"); thread_id[0] = '\0'; thread_id++; - size_t rank_count = safe_strtol(rank_count_str, nullptr, 10); + size_t rank_count; + 
KVS_CHECK_STATUS(safe_strtol(rank_count_str, rank_count), + "failed to convert rank_count"); communicators[request.key].local_size += rank_count; socket_info sock_info{ socket, proc_id, { rank, rank_count, thread_id } }; communicator.processes[proc_id].push_back(sock_info.process_info); @@ -351,15 +362,16 @@ void server::make_client_request(int& socket) { } default: { if (request.name[0] == '\0') - return; - printf("server: unknown request mode - %d.\n", request.mode); - exit(EXIT_FAILURE); + return KVS_STATUS_SUCCESS; + LOG_ERROR("unknown request mode - %d.\n", request.mode); + return KVS_STATUS_FAILURE; } } + return KVS_STATUS_SUCCESS; } -bool server::check_finalize() { - bool to_finalize = false; +kvs_status_t server::check_finalize(bool& to_finalize) { + to_finalize = false; if (poll_fds[FDI_CONTROL].revents != 0) { DO_RW_OP_1(read, poll_fds[FDI_CONTROL].fd, @@ -372,15 +384,15 @@ bool server::check_finalize() { poll_fds[FDI_CONTROL].fd = free_socket; } if (request.mode != AM_FINALIZE) { - printf("server: invalid access mode for local socket\n"); - exit(EXIT_FAILURE); + LOG_ERROR("invalid access mode for local socket\n"); + return KVS_STATUS_FAILURE; } to_finalize = true; } - return to_finalize; + return KVS_STATUS_SUCCESS; } -void server::run(void* args) { +kvs_status_t server::run(void* args) { bool should_stop = false; int so_reuse = 1; poll_fds.resize(client_count_increase); @@ -398,13 +410,13 @@ void server::run(void* args) { #endif if (listen(poll_fds[FDI_LISTENER].fd, max_client_queue_size) < 0) { - LOG_ERROR("server: server_listen_sock listen"); - exit(EXIT_FAILURE); + LOG_ERROR("server_listen_sock listen(%s)", strerror(errno)); + return KVS_STATUS_FAILURE; } if ((poll_fds[FDI_CONTROL].fd = socket(address_family, SOCK_STREAM, 0)) < 0) { - perror("server: server_control_sock init"); - exit(EXIT_FAILURE); + LOG_ERROR("server_control_sock init(%s)", strerror(errno)); + return KVS_STATUS_FAILURE; } while (connect(poll_fds[FDI_CONTROL].fd, @@ -414,8 +426,8 @@ 
void server::run(void* args) { while (!should_stop || client_count > 0) { if (poll(poll_fds.data(), poll_fds.size(), -1) < 0) { if (errno != EINTR) { - perror("server: poll"); - exit(EXIT_FAILURE); + LOG_ERROR("poll(%s)", strerror(errno)); + return KVS_STATUS_FAILURE; } else { /* restart select */ @@ -425,12 +437,12 @@ void server::run(void* args) { for (size_t i = FDI_LAST; i < poll_fds.size(); i++) { if (poll_fds[i].fd != free_socket && poll_fds[i].revents != 0) { - make_client_request(poll_fds[i].fd); + KVS_CHECK_STATUS(make_client_request(poll_fds[i].fd), "failed to make request"); } } - try_to_connect_new(); + KVS_CHECK_STATUS(try_to_connect_new(), "failed to connect new"); if (!should_stop) { - should_stop = check_finalize(); + KVS_CHECK_STATUS(check_finalize(should_stop), "failed to check finalize"); } } @@ -455,12 +467,15 @@ void server::run(void* args) { close(poll_fds[FDI_LISTENER].fd); poll_fds[FDI_LISTENER].fd = free_socket; + return KVS_STATUS_SUCCESS; } void* kvs_server_init(void* args) { server s; - s.run(args); + if (s.run(args) != KVS_STATUS_SUCCESS) { + LOG_ERROR("failed"); + } return nullptr; } diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs_server.hpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs_server.hpp index 12590aa54..2c8d17cd5 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs_server.hpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/internal_kvs_server.hpp @@ -32,10 +32,10 @@ typedef enum kvs_access_mode { } kvs_access_mode_t; typedef struct kvs_request { - kvs_access_mode_t mode; - char name[MAX_KVS_NAME_LENGTH]; - char key[MAX_KVS_KEY_LENGTH]; - char val[MAX_KVS_VAL_LENGTH]; + kvs_access_mode_t mode{ AM_PUT }; + char name[MAX_KVS_NAME_LENGTH]{}; + char key[MAX_KVS_KEY_LENGTH]{}; + char val[MAX_KVS_VAL_LENGTH]{}; } kvs_request_t; typedef struct server_args { diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.cpp 
b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.cpp index e648f1fbc..ff039a5f3 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.cpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.cpp @@ -20,70 +20,80 @@ users_kvs::users_kvs(std::shared_ptr kvs) : kvs(kvs) {} -size_t users_kvs::kvs_set_value(const char* kvs_name, const char* kvs_key, const char* kvs_val) { +kvs_status_t users_kvs::kvs_set_value(const char* kvs_name, + const char* kvs_key, + const char* kvs_val) { ccl::string_class name(kvs_name), key(kvs_key); ccl::vector_class vec_val(kvs_val, kvs_val + strlen(kvs_val) + 1); vec_val[strlen(kvs_val)] = '\0'; kvs->set(name + key, vec_val); - return 0; + return KVS_STATUS_SUCCESS; } -size_t users_kvs::kvs_remove_name_key(const char* kvs_name, const char* kvs_key) { +kvs_status_t users_kvs::kvs_remove_name_key(const char* kvs_name, const char* kvs_key) { ccl::vector_class kvs_val = { '\0' }; ccl::string_class name(kvs_name), key(kvs_key); kvs->set(name + key, kvs_val); - return 0; + return KVS_STATUS_SUCCESS; } -size_t users_kvs::kvs_get_value_by_name_key(const char* kvs_name, - const char* kvs_key, - char* kvs_val) { +kvs_status_t users_kvs::kvs_get_value_by_name_key(const char* kvs_name, + const char* kvs_key, + char* kvs_val) { ccl::string_class name(kvs_name), key(kvs_key); ccl::vector_class res = kvs->get(name + key); + memset(kvs_val, 0, MAX_KVS_VAL_LENGTH); if (res.data()) SET_STR(kvs_val, MAX_KVS_VAL_LENGTH, "%s", res.data()); else SET_STR(kvs_val, MAX_KVS_VAL_LENGTH, "%s", ""); - return strlen(kvs_val); + return KVS_STATUS_SUCCESS; } -size_t users_kvs::kvs_get_count_names(const char* kvs_name) { +kvs_status_t users_kvs::kvs_get_count_names(const char* kvs_name, int& count_names) { /*TODO: Unsupported*/ (void)kvs_name; - return 0; + LOG_ERROR("unsupported"); + return KVS_STATUS_UNSUPPORTED; } -size_t users_kvs::kvs_get_keys_values_by_name(const char* kvs_name, - char*** kvs_keys, - char*** kvs_values) 
{ +kvs_status_t users_kvs::kvs_get_keys_values_by_name(const char* kvs_name, + char*** kvs_keys, + char*** kvs_values, + size_t& count) { /*TODO: Unsupported*/ (void)kvs_name; (void)kvs_keys; (void)kvs_values; - return 0; + LOG_ERROR("unsupported"); + return KVS_STATUS_UNSUPPORTED; } -size_t users_kvs::kvs_get_replica_size(void) { +kvs_status_t users_kvs::kvs_get_replica_size(size_t& replica_size) { /*TODO: Unsupported*/ - return 0; + LOG_ERROR("unsupported"); + return KVS_STATUS_UNSUPPORTED; } -size_t users_kvs::kvs_main_server_address_reserve(char* main_address) { +kvs_status_t users_kvs::kvs_main_server_address_reserve(char* main_address) { /*TODO: Unsupported*/ (void)main_address; - return 0; + LOG_ERROR("unsupported"); + return KVS_STATUS_UNSUPPORTED; } -size_t users_kvs::kvs_init(const char* main_addr) { +kvs_status_t users_kvs::kvs_init(const char* main_addr) { /*TODO: Unsupported*/ (void)main_addr; - return 0; + LOG_ERROR("unsupported"); + return KVS_STATUS_UNSUPPORTED; } -size_t users_kvs::kvs_finalize(void) { +kvs_status_t users_kvs::kvs_finalize(void) { /*TODO: Unsupported*/ - return 0; + LOG_ERROR("unsupported"); + return KVS_STATUS_UNSUPPORTED; } diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.h b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.h index 1d220ebff..6d180e764 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.h +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.h @@ -27,27 +27,30 @@ class users_kvs final : public ikvs_wrapper { ~users_kvs() = default; - size_t kvs_set_value(const char* kvs_name, const char* kvs_key, const char* kvs_val) override; + kvs_status_t kvs_set_value(const char* kvs_name, + const char* kvs_key, + const char* kvs_val) override; - size_t kvs_remove_name_key(const char* kvs_name, const char* kvs_key) override; + kvs_status_t kvs_remove_name_key(const char* kvs_name, const char* kvs_key) override; - size_t kvs_get_value_by_name_key(const 
char* kvs_name, - const char* kvs_key, - char* kvs_val) override; + kvs_status_t kvs_get_value_by_name_key(const char* kvs_name, + const char* kvs_key, + char* kvs_val) override; - size_t kvs_init(const char* main_addr) override; + kvs_status_t kvs_init(const char* main_addr) override; - size_t kvs_main_server_address_reserve(char* main_addr) override; + kvs_status_t kvs_main_server_address_reserve(char* main_addr) override; - size_t kvs_get_count_names(const char* kvs_name) override; + kvs_status_t kvs_get_count_names(const char* kvs_name, int& count_names) override; - size_t kvs_finalize(void) override; + kvs_status_t kvs_finalize(void) override; - size_t kvs_get_keys_values_by_name(const char* kvs_name, - char*** kvs_keys, - char*** kvs_values) override; + kvs_status_t kvs_get_keys_values_by_name(const char* kvs_name, + char*** kvs_keys, + char*** kvs_values, + size_t& count) override; - size_t kvs_get_replica_size(void) override; + kvs_status_t kvs_get_replica_size(size_t& replica_size) override; private: std::shared_ptr kvs; diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/pmi_listener.cpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/pmi_listener.cpp index 6f1ea920b..cb2e4d157 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/pmi_listener.cpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/pmi_listener.cpp @@ -29,11 +29,6 @@ #define LISTENER_TIMEOUT 5 -enum return_status { - get_new = 0, - timeout = 1, -}; - static int sock_sender; static size_t num_listeners; static int sock_listener = -1; @@ -45,10 +40,10 @@ void pmi_listener::set_applied_count(int count) { num_changes -= count; } -int pmi_listener::collect_sock_addr(std::shared_ptr h) { +kvs_status_t pmi_listener::collect_sock_addr(std::shared_ptr h) { FILE* fp; size_t i, j; - int res = 0; + kvs_status_t res = KVS_STATUS_SUCCESS; size_t glob_num_listeners; char** sock_addr_str = NULL; char** hosts_names_str = NULL; @@ -56,8 +51,8 @@ int 
pmi_listener::collect_sock_addr(std::shared_ptr h) { char* point_to_space; if ((fp = popen(GET_IP_CMD, READ_ONLY)) == NULL) { - printf("Can't get host IP\n"); - exit(1); + LOG_ERROR("Can't get host IP"); + return KVS_STATUS_FAILURE; } CHECK_FGETS(fgets(my_ip, MAX_KVS_VAL_LENGTH, fp), my_ip); pclose(fp); @@ -66,7 +61,9 @@ int pmi_listener::collect_sock_addr(std::shared_ptr h) { if ((point_to_space = strstr(my_ip, " ")) != NULL) point_to_space[0] = NULL_CHAR; - glob_num_listeners = h->get_keys_values_by_name(KVS_LISTENER, &hosts_names_str, &sock_addr_str); + KVS_CHECK_STATUS(h->get_keys_values_by_name( + KVS_LISTENER, &hosts_names_str, &sock_addr_str, glob_num_listeners), + "failed to get sock info"); num_listeners = glob_num_listeners; for (i = 0; i < num_listeners; i++) { @@ -77,13 +74,13 @@ int pmi_listener::collect_sock_addr(std::shared_ptr h) { } if (num_listeners == 0) { - res = 0; + res = KVS_STATUS_SUCCESS; goto exit; } if ((sock_sender = socket(AF_INET, SOCK_DGRAM, 0)) < 0) { - printf("\n Socket creation error \n"); - res = -1; + LOG_ERROR("Socket creation error"); + res = KVS_STATUS_FAILURE; goto exit; } @@ -93,8 +90,8 @@ int pmi_listener::collect_sock_addr(std::shared_ptr h) { server_addresses = (struct sockaddr_in*)malloc((num_listeners) * sizeof(struct sockaddr_in)); if (server_addresses == NULL) { - printf("\nmemory allocation failed \n"); - res = -1; + LOG_ERROR("nmemory allocation failed"); + res = KVS_STATUS_FAILURE; goto exit; } @@ -102,8 +99,8 @@ int pmi_listener::collect_sock_addr(std::shared_ptr h) { for (i = 0, j = 0; i < num_listeners; i++, j++) { char* point_to_port = strstr(sock_addr_str[j], "_"); if (point_to_port == NULL) { - printf("\nlistener: Wrong address_port record: %s\n", sock_addr_str[j]); - res = -1; + LOG_ERROR("Wrong address_port record: %s", sock_addr_str[j]); + res = KVS_STATUS_FAILURE; goto exit; } point_to_port[0] = NULL_CHAR; @@ -113,12 +110,16 @@ int pmi_listener::collect_sock_addr(std::shared_ptr h) { continue; } - 
server_addresses[i].sin_port = safe_strtol(point_to_port, NULL, 10); + if (safe_strtol(point_to_port, server_addresses[i].sin_port) != KVS_STATUS_SUCCESS) { + LOG_ERROR("failed to convert sin_port"); + res = KVS_STATUS_FAILURE; + goto exit; + } server_addresses[i].sin_family = AF_INET; if (inet_pton(AF_INET, sock_addr_str[j], &(server_addresses[i].sin_addr)) <= 0) { - printf("\nlist: Invalid address/ Address not supported: %s\n", sock_addr_str[j]); - res = -1; + LOG_ERROR("Invalid address/ Address not supported: %s", sock_addr_str[j]); + res = KVS_STATUS_FAILURE; goto exit; } } @@ -132,16 +133,17 @@ int pmi_listener::collect_sock_addr(std::shared_ptr h) { return res; } -void pmi_listener::clean_listener(std::shared_ptr h) { - h->remove_name_key(KVS_LISTENER, my_hostname); +kvs_status_t pmi_listener::clean_listener(std::shared_ptr h) { + KVS_CHECK_STATUS(h->remove_name_key(KVS_LISTENER, my_hostname), "failed to remove host info"); close(sock_listener); + return KVS_STATUS_SUCCESS; } -void pmi_listener::send_notification(int sig, std::shared_ptr h) { +kvs_status_t pmi_listener::send_notification(int sig, std::shared_ptr h) { size_t i; char message[INT_STR_SIZE]; - collect_sock_addr(h); + KVS_CHECK_STATUS(collect_sock_addr(h), "failed to collect sock info"); SET_STR(message, INT_STR_SIZE, "%s", "Update!"); for (i = 0; i < num_listeners; ++i) { @@ -152,11 +154,13 @@ void pmi_listener::send_notification(int sig, std::shared_ptr h) { (const struct sockaddr*)&(server_addresses[i]), sizeof(server_addresses[i])); } - if (sig) - clean_listener(h); + if (sig) { + KVS_CHECK_STATUS(clean_listener(h), "failed to clean listener"); + } + return KVS_STATUS_SUCCESS; } -int pmi_listener::run_listener(std::shared_ptr h) { +kvs_status_t pmi_listener::run_listener(std::shared_ptr h) { socklen_t len = 0; char recv_buf[INT_STR_SIZE]; memset(recv_buf, 0, INT_STR_SIZE); @@ -181,8 +185,10 @@ int pmi_listener::run_listener(std::shared_ptr h) { my_ip[strlen(my_ip) - 1] = '\0'; if 
((point_to_space = strstr(my_ip, " ")) != NULL) point_to_space[0] = NULL_CHAR; - if ((sock_listener = socket(AF_INET, SOCK_DGRAM, 0)) < 0) - return 1; + if ((sock_listener = socket(AF_INET, SOCK_DGRAM, 0)) < 0) { + LOG_ERROR("socket error(%s)", strerror(errno)); + return KVS_STATUS_FAILURE; + } memset(&addr, 0, sizeof(addr)); @@ -190,14 +196,17 @@ int pmi_listener::run_listener(std::shared_ptr h) { addr.sin_addr.s_addr = INADDR_ANY; addr.sin_port = 0; - if (bind(sock_listener, (const struct sockaddr*)&addr, sizeof(addr)) < 0) - return 1; + if (bind(sock_listener, (const struct sockaddr*)&addr, sizeof(addr)) < 0) { + LOG_ERROR("bind error(%s)", strerror(errno)); + return KVS_STATUS_FAILURE; + } getsockname(sock_listener, (struct sockaddr*)&addr, (socklen_t*)&addr_len); SET_STR( addr_for_kvs, REQUEST_POSTFIX_SIZE, KVS_NAME_TEMPLATE_I, my_ip, (size_t)addr.sin_port); - h->set_value(KVS_LISTENER, my_hostname, addr_for_kvs); + KVS_CHECK_STATUS(h->set_value(KVS_LISTENER, my_hostname, addr_for_kvs), + "failed to set addr info"); if (setsockopt(sock_listener, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) { perror("Error"); } @@ -213,14 +222,15 @@ int pmi_listener::run_listener(std::shared_ptr h) { &len); if (ret == -1) { if (errno == EAGAIN) { - return timeout; + return KVS_STATUS_SUCCESS; } if (errno != EINTR) { - printf("listner: accept error: %s\n", strerror(errno)); + LOG_ERROR("listner: accept error: %s\n", strerror(errno)); + return KVS_STATUS_FAILURE; } } num_changes++; } - return get_new; + return KVS_STATUS_SUCCESS; } diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/pmi_listener.hpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/pmi_listener.hpp index e845ea50d..7616dafbb 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/pmi_listener.hpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/pmi_listener.hpp @@ -20,14 +20,14 @@ class pmi_listener { public: - void send_notification(int sig, std::shared_ptr h); + kvs_status_t 
send_notification(int sig, std::shared_ptr h); void set_applied_count(int count); - int run_listener(std::shared_ptr h); + kvs_status_t run_listener(std::shared_ptr h); private: - int collect_sock_addr(std::shared_ptr h); - void clean_listener(std::shared_ptr h); + kvs_status_t collect_sock_addr(std::shared_ptr h); + kvs_status_t clean_listener(std::shared_ptr h); }; #endif diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/rank_list.cpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/rank_list.cpp deleted file mode 100644 index 712c69562..000000000 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/rank_list.cpp +++ /dev/null @@ -1,104 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-*/ -#include -#include - -#include "rank_list.hpp" - -void rank_list_sort(rank_list_t* list) { - rank_list_t* left = list; - rank_list_t* right; - - while (left != NULL) { - right = left->next; - while (right != NULL) { - if (left->rank > right->rank) { - int tmp_i = left->rank; - left->rank = right->rank; - right->rank = tmp_i; - } - right = right->next; - } - left = left->next; - } -} - -void rank_list_clean(rank_list_t** list) { - rank_list_t* cur_list = *list; - rank_list_t* node_to_remove; - - while (cur_list != NULL) { - node_to_remove = cur_list; - cur_list = cur_list->next; - free(node_to_remove); - } - *list = NULL; -} - -size_t rank_list_contains(rank_list_t* list, int rank) { - rank_list_t* cur_list = list; - - while (cur_list != NULL) { - if (cur_list->rank == rank) - return 1; - cur_list = cur_list->next; - } - return 0; -} - -void rank_list_keep_first_n(rank_list_t** origin_list, size_t n) { - rank_list_t* cur_node = (*origin_list); - rank_list_t* tmp_node = NULL; - size_t i; - - for (i = 0; i < n; i++) { - tmp_node = cur_node; - cur_node = cur_node->next; - } - - if (tmp_node != NULL) - tmp_node->next = NULL; - - while (cur_node != NULL) { - tmp_node = cur_node; - cur_node = cur_node->next; - free(tmp_node); - } - if (n == 0) - (*origin_list) = NULL; -} - -void rank_list_add(rank_list_t** origin_list, int rank) { - if ((*origin_list) == NULL) { - (*origin_list) = (rank_list_t*)malloc(sizeof(rank_list_t)); - if ((*origin_list) == NULL) { - printf("Memory allocation failed\n"); - return; - } - (*origin_list)->next = NULL; - (*origin_list)->rank = rank; - } - else { - rank_list_t* cur_list; - cur_list = (*origin_list); - while (cur_list->next != NULL) - cur_list = cur_list->next; - cur_list->next = (rank_list_t*)malloc(sizeof(rank_list_t)); - cur_list = cur_list->next; - cur_list->next = NULL; - cur_list->rank = rank; - } -} diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/rank_list.hpp 
b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/rank_list.hpp deleted file mode 100644 index 064e244d6..000000000 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/rank_list.hpp +++ /dev/null @@ -1,40 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#ifndef INT_LIST_H_INCLUDED -#define INT_LIST_H_INCLUDED - -#ifdef __cplusplus -extern "C" { -#endif -typedef struct rank_list { - int rank; - struct rank_list* next; -} rank_list_t; - -size_t rank_list_contains(rank_list_t* list, int rank); - -void rank_list_clean(rank_list_t** list); - -void rank_list_sort(rank_list_t* list); - -void rank_list_keep_first_n(rank_list_t** origin_list, size_t n); - -void rank_list_add(rank_list_t** origin_list, int rank); - -#ifdef __cplusplus -} -#endif -#endif diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.cpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.cpp index 7d525b7e5..6886e4732 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.cpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.cpp @@ -55,6 +55,14 @@ char master_addr[MAX_KVS_NAME_LENGTH]; #define GET_KEY "| sed -r 's/\"[a-zA-Z0-9_]*-|: \"[a-zA-Z0-9_-]*|,|\"| |//g'" #define GET_VAL "| sed -r 's/[a-zA-Z0-9_-]*\":|,|\"| |//g'" +#define CHECK_STR(expr, str) \ + do { \ + if (!(expr)) { \ + LOG_ERROR("wrong str: ", str); \ + return KVS_STATUS_FAILURE; \ + } \ + } while (0) + 
char run_get_template[RUN_TEMPLATE_SIZE]; char run_set_template[RUN_TEMPLATE_SIZE]; char job_name[MAX_KVS_NAME_LENGTH]; @@ -66,19 +74,22 @@ typedef enum manager_type { manager_type_t manager; -size_t request_k8s_get_keys_values_by_name(const char* kvs_name, - char*** kvs_key, - char*** kvs_values); +kvs_status_t request_k8s_get_keys_values_by_name(const char* kvs_name, + char*** kvs_key, + char*** kvs_values, + int& values_count); -size_t request_k8s_get_count_names(const char* kvs_name); +kvs_status_t request_k8s_get_count_names(const char* kvs_name, size_t& res); -size_t request_k8s_get_val_by_name_key(const char* kvs_name, const char* kvs_key, char* kvs_val); +kvs_status_t request_k8s_get_val_by_name_key(const char* kvs_name, + const char* kvs_key, + char* kvs_val); -size_t request_k8s_remove_name_key(const char* kvs_name, const char* kvs_key); +kvs_status_t request_k8s_remove_name_key(const char* kvs_name, const char* kvs_key); -size_t request_k8s_set_val(const char* kvs_name, const char* kvs_key, const char* kvs_val); +kvs_status_t request_k8s_set_val(const char* kvs_name, const char* kvs_key, const char* kvs_val); -void json_get_val(FILE* fp, const char** keys, size_t keys_count, char* val) { +kvs_status_t json_get_val(FILE* fp, const char** keys, size_t keys_count, char* val) { char cur_kvs_str[MAX_KVS_STR_LENGTH]; char* res; char last_char; @@ -101,26 +112,33 @@ void json_get_val(FILE* fp, const char** keys, size_t keys_count, char* val) { wrong_namespace_depth--; } } - res = strstr(cur_kvs_str, ":"); - res++; - while (res[0] == ' ') + CHECK_STR(res = strstr(cur_kvs_str, ":"), cur_kvs_str); + do { res++; + CHECK_STR(res, cur_kvs_str); + } while (res[0] == ' '); - if (res[0] == '"' || res[0] == '\'') + if (res[0] == '"' || res[0] == '\'') { res++; + CHECK_STR(res, cur_kvs_str); + } - last_char = res[strlen(res) - 1]; + int str_len = strlen(res) - 1; + last_char = res[str_len]; while (last_char == '\n' || last_char == ',' || last_char == ' ' || last_char == 
'"' || last_char == ' ') { - res[strlen(res) - 1] = '\0'; - last_char = res[strlen(res) - 1]; + res[str_len] = '\0'; + str_len--; + CHECK_STR(str_len, cur_kvs_str); + last_char = res[str_len]; } kvs_str_copy(val, res, MAX_KVS_VAL_LENGTH); while (fgets(cur_kvs_str, MAX_KVS_STR_LENGTH, fp)) { } + return KVS_STATUS_SUCCESS; } -size_t k8s_init_with_manager() { +kvs_status_t k8s_init_with_manager() { FILE* fp; FILE* fp_name; FILE* fp_type; @@ -137,16 +155,21 @@ size_t k8s_init_with_manager() { char pod_name[MAX_KVS_VAL_LENGTH]; memset(pod_name, '\0', MAX_KVS_VAL_LENGTH); if ((fp = popen("hostname", READ_ONLY)) == NULL) { - printf("Can't get hostname\n"); - exit(1); + LOG_ERROR("Can't get hostname\n"); + return KVS_STATUS_FAILURE; } CHECK_FGETS(fgets(pod_name, MAX_KVS_VAL_LENGTH, fp), pod_name); pclose(fp); - while (pod_name[strlen(pod_name) - 1] == '\n' || pod_name[strlen(pod_name) - 1] == ' ') - pod_name[strlen(pod_name) - 1] = '\0'; + int str_len = strlen(pod_name) - 1; + CHECK_STR(str_len, "hostname"); + while (pod_name[str_len] == '\n' || pod_name[str_len] == ' ') { + pod_name[str_len] = '\0'; + str_len--; + CHECK_STR(str_len, "hostname"); + } if (kube_api_addr == NULL) { - printf("%s not set\n", CCL_K8S_API_ADDR_ENV); - return 1; + LOG_ERROR("%s not set\n", CCL_K8S_API_ADDR_ENV); + return KVS_STATUS_FAILURE; } SET_STR(connect_api_template, RUN_TEMPLATE_SIZE, ADDR_STR_V1_TEMPLATE, kube_api_addr); @@ -156,10 +179,10 @@ size_t k8s_init_with_manager() { memset(kind_type, NULL_CHAR, MAX_KVS_NAME_LENGTH); if ((fp_name = popen(run_str, READ_ONLY)) == NULL) { - printf("Can't get kind_type\n"); - exit(1); + LOG_ERROR("Can't get kind_type\n"); + return KVS_STATUS_FAILURE; } - json_get_val(fp_name, kind_type_key, 3, kind_type); + KVS_CHECK_STATUS(json_get_val(fp_name, kind_type_key, 3, kind_type), "failed to get type"); /*we must use the plural to access to statefulset/deployment KVS*/ kind_type_size = strlen(kind_type); @@ -170,10 +193,10 @@ size_t k8s_init_with_manager() { 
memset(kind_name, NULL_CHAR, MAX_KVS_NAME_LENGTH); if ((fp_type = popen(run_str, READ_ONLY)) == NULL) { - printf("Can't get kind_name\n"); - exit(1); + LOG_ERROR("Can't get kind_name\n"); + return KVS_STATUS_FAILURE; } - json_get_val(fp_type, kind_name_key, 3, kind_name); + KVS_CHECK_STATUS(json_get_val(fp_type, kind_name_key, 3, kind_name), "filed to get name"); SET_STR(kind_path, MAX_KVS_NAME_LENGTH, "%s/%s", kind_type, kind_name); SET_STR(connect_api_template, RUN_TEMPLATE_SIZE, ADDR_STR_V2_TEMPLATE, kube_api_addr); @@ -193,10 +216,10 @@ size_t k8s_init_with_manager() { pclose(fp_name); pclose(fp_type); - return 0; + return KVS_STATUS_SUCCESS; } -void get_my_job_name(const char* connect_api_template) { +kvs_status_t get_my_job_name(const char* connect_api_template) { FILE* fp; char run_str[RUN_REQUEST_SIZE]; char grep_kvs_name_key[REQUEST_POSTFIX_SIZE]; @@ -204,13 +227,18 @@ void get_my_job_name(const char* connect_api_template) { char pod_name[MAX_KVS_VAL_LENGTH]; memset(pod_name, '\0', MAX_KVS_VAL_LENGTH); if ((fp = popen("hostname", READ_ONLY)) == NULL) { - printf("Can't get hostname\n"); - exit(1); + LOG_ERROR("Can't get hostname\n"); + return KVS_STATUS_FAILURE; } CHECK_FGETS(fgets(pod_name, MAX_KVS_VAL_LENGTH, fp), pod_name); pclose(fp); - while (pod_name[strlen(pod_name) - 1] == '\n' || pod_name[strlen(pod_name) - 1] == ' ') - pod_name[strlen(pod_name) - 1] = '\0'; + int str_len = strlen(pod_name) - 1; + CHECK_STR(str_len, "hostname"); + while (pod_name[str_len] == '\n' || pod_name[str_len] == ' ') { + pod_name[str_len] = '\0'; + str_len--; + CHECK_STR(str_len, "hostname"); + } SET_STR(grep_kvs_name_key, REQUEST_POSTFIX_SIZE, GREP_TEMPLATE, JOB_NAME); SET_STR( @@ -224,8 +252,8 @@ void get_my_job_name(const char* connect_api_template) { get_kvs_val); if ((fp = popen(run_str, READ_ONLY)) == NULL) { - printf("Can't get %s", strerror(errno)); - exit(1); + LOG_ERROR("Can't get %s", strerror(errno)); + return KVS_STATUS_FAILURE; } CHECK_FGETS(fgets(job_name, 
MAX_KVS_NAME_LENGTH, fp), job_name); pclose(fp); @@ -236,26 +264,32 @@ void get_my_job_name(const char* connect_api_template) { else { job_name[strlen(job_name) - 1] = '_'; } + return KVS_STATUS_SUCCESS; } -size_t k8s_init_without_manager() { +kvs_status_t k8s_init_without_manager() { FILE* fp; char* kube_api_addr = getenv(CCL_K8S_API_ADDR_ENV); char connect_api_template[RUN_TEMPLATE_SIZE]; char pod_name[MAX_KVS_VAL_LENGTH]; memset(pod_name, '\0', MAX_KVS_VAL_LENGTH); if ((fp = popen("hostname", READ_ONLY)) == NULL) { - printf("Can't get hostname\n"); - exit(1); + LOG_ERROR("Can't get hostname\n"); + return KVS_STATUS_FAILURE; } CHECK_FGETS(fgets(pod_name, MAX_KVS_VAL_LENGTH, fp), pod_name); pclose(fp); - while (pod_name[strlen(pod_name) - 1] == '\n' || pod_name[strlen(pod_name) - 1] == ' ') - pod_name[strlen(pod_name) - 1] = '\0'; + int str_len = strlen(pod_name) - 1; + CHECK_STR(str_len, "hostname"); + while (pod_name[str_len] == '\n' || pod_name[str_len] == ' ') { + pod_name[str_len] = '\0'; + str_len--; + CHECK_STR(str_len, "hostname"); + } if (kube_api_addr == NULL) { - printf("%s not set\n", CCL_K8S_API_ADDR_ENV); - return 1; + LOG_ERROR("%s not set\n", CCL_K8S_API_ADDR_ENV); + return KVS_STATUS_FAILURE; } SET_STR(connect_api_template, RUN_TEMPLATE_SIZE, ADDR_STR_V1_TEMPLATE, kube_api_addr); @@ -272,13 +306,12 @@ size_t k8s_init_without_manager() { pod_name, "%s"); - get_my_job_name(connect_api_template); + KVS_CHECK_STATUS(get_my_job_name(connect_api_template), "failed to get job name"); - return 0; + return KVS_STATUS_SUCCESS; } -size_t request_k8s_kvs_init() { - size_t res = 1; +kvs_status_t request_k8s_kvs_init() { char* manager_type_env = getenv(CCL_K8S_MANAGER_TYPE_ENV); if (!manager_type_env || strstr(manager_type_env, "none")) { @@ -288,7 +321,7 @@ size_t request_k8s_kvs_init() { manager = MT_K8S; } else { - printf( + LOG_WARN( "Unknown %s = %s, running with \"none\"\n", CCL_K8S_MANAGER_TYPE_ENV, manager_type_env); manager = MT_NONE; } @@ -296,8 
+329,11 @@ size_t request_k8s_kvs_init() { memset(job_name, NULL_CHAR, MAX_KVS_NAME_LENGTH); switch (manager) { - case MT_NONE: res = k8s_init_without_manager(); break; - case MT_K8S: res = k8s_init_with_manager(); break; + case MT_NONE: + KVS_CHECK_STATUS(k8s_init_without_manager(), "failed to initialize k8z"); + break; + case MT_K8S: KVS_CHECK_STATUS(k8s_init_with_manager(), "failed to initialize k8z"); break; + default: LOG_ERROR("unknown k8s manager"); return KVS_STATUS_FAILURE; } memset(ccl_kvs_ip, NULL_CHAR, MAX_KVS_NAME_LENGTH); @@ -310,35 +346,51 @@ size_t request_k8s_kvs_init() { SET_STR(req_kvs_ip, MAX_KVS_NAME_LENGTH, KVS_NAME_TEMPLATE_S, job_name, REQ_KVS_IP); SET_STR(master_addr, MAX_KVS_NAME_LENGTH, KVS_NAME_TEMPLATE_S, job_name, MASTER_ADDR); - return res; + return KVS_STATUS_SUCCESS; } -size_t request_k8s_kvs_get_master(const char* local_host_ip, char* main_host_ip, char* port_str) { +kvs_status_t request_k8s_kvs_get_master(const char* local_host_ip, + char* main_host_ip, + char* port_str) { char** kvs_values = NULL; char** kvs_keys = NULL; int values_count = 0; - request_k8s_set_val(ccl_kvs_ip, my_hostname, local_host_ip); - request_k8s_set_val(ccl_kvs_port, my_hostname, port_str); - - if (!request_k8s_get_count_names(master_addr)) { - values_count = request_k8s_get_keys_values_by_name(ccl_kvs_ip, &kvs_keys, &kvs_values); + KVS_CHECK_STATUS(request_k8s_set_val(ccl_kvs_ip, my_hostname, local_host_ip), + "failed to set IP"); + KVS_CHECK_STATUS(request_k8s_set_val(ccl_kvs_port, my_hostname, port_str), + "failed to set port"); + size_t count; + KVS_CHECK_STATUS(request_k8s_get_count_names(master_addr, count), "failed to get names count"); + if (count == 0) { + KVS_CHECK_STATUS( + request_k8s_get_keys_values_by_name(ccl_kvs_ip, &kvs_keys, &kvs_values, values_count), + "failed to get keys"); if (strstr(kvs_keys[0], my_hostname)) { - request_k8s_set_val(req_kvs_ip, my_hostname, local_host_ip); - while (!request_k8s_get_count_names(master_addr)) { - 
values_count = - request_k8s_get_keys_values_by_name(req_kvs_ip, &kvs_keys, &kvs_values); + KVS_CHECK_STATUS(request_k8s_set_val(req_kvs_ip, my_hostname, local_host_ip), + "failed to set IP"); + KVS_CHECK_STATUS(request_k8s_get_count_names(master_addr, count), + "failed to get names count"); + while (count == 0) { + KVS_CHECK_STATUS(request_k8s_get_keys_values_by_name( + req_kvs_ip, &kvs_keys, &kvs_values, values_count), + "failed to get keys values"); if (values_count > 1) { if (!strstr(kvs_keys[0], my_hostname)) { break; } } else { - request_k8s_set_val(master_addr, KVS_IP, local_host_ip); - request_k8s_set_val(master_addr, KVS_PORT, port_str); + KVS_CHECK_STATUS(request_k8s_set_val(master_addr, KVS_IP, local_host_ip), + "failed to set IP"); + KVS_CHECK_STATUS(request_k8s_set_val(master_addr, KVS_PORT, port_str), + "failed to set port"); } + KVS_CHECK_STATUS(request_k8s_get_count_names(master_addr, count), + "failed to get names count"); } - request_k8s_remove_name_key(req_kvs_ip, my_hostname); + KVS_CHECK_STATUS(request_k8s_remove_name_key(req_kvs_ip, my_hostname), + "failed to remove host info"); } if (kvs_keys != NULL) { for (int i = 0; i < values_count; i++) { @@ -353,29 +405,36 @@ size_t request_k8s_kvs_get_master(const char* local_host_ip, char* main_host_ip, free(kvs_values); } } - while (!request_k8s_get_count_names(master_addr)) { + do { + KVS_CHECK_STATUS(request_k8s_get_count_names(master_addr, count), + "failed to get names count"); sleep(1); - } - request_k8s_get_val_by_name_key(master_addr, KVS_IP, main_host_ip); - request_k8s_get_val_by_name_key(master_addr, KVS_PORT, port_str); - return 0; + } while (count == 0); + KVS_CHECK_STATUS(request_k8s_get_val_by_name_key(master_addr, KVS_IP, main_host_ip), + "failed to get IP"); + KVS_CHECK_STATUS(request_k8s_get_val_by_name_key(master_addr, KVS_PORT, port_str), + "failed to get port"); + return KVS_STATUS_SUCCESS; } -size_t request_k8s_kvs_finalize(size_t is_master) { - 
request_k8s_remove_name_key(ccl_kvs_ip, my_hostname); - request_k8s_remove_name_key(ccl_kvs_port, my_hostname); +kvs_status_t request_k8s_kvs_finalize(size_t is_master) { + KVS_CHECK_STATUS(request_k8s_remove_name_key(ccl_kvs_ip, my_hostname), "failed to remove IP"); + KVS_CHECK_STATUS(request_k8s_remove_name_key(ccl_kvs_port, my_hostname), + "failed to remove port"); if (is_master) { - request_k8s_remove_name_key(master_addr, KVS_IP); - request_k8s_remove_name_key(master_addr, KVS_PORT); + KVS_CHECK_STATUS(request_k8s_remove_name_key(master_addr, KVS_IP), + "failed to remove master IP"); + KVS_CHECK_STATUS(request_k8s_remove_name_key(master_addr, KVS_PORT), + "failed to remove master IP"); } - return 0; + return KVS_STATUS_SUCCESS; } -size_t get_by_template(char*** kvs_entry, - const char* request, - const char* template_str, - int count, - int max_count) { +kvs_status_t get_by_template(char*** kvs_entry, + const char* request, + const char* template_str, + int count, + int max_count) { FILE* fp; char get_val[REQUEST_POSTFIX_SIZE]; char run_str[RUN_REQUEST_SIZE]; @@ -386,14 +445,14 @@ size_t get_by_template(char*** kvs_entry, *kvs_entry = (char**)malloc(sizeof(char*) * count); if (*kvs_entry == NULL) { - printf("Memory allocation failed\n"); - exit(1); + LOG_ERROR("Memory allocation failed\n"); + return KVS_STATUS_FAILURE; } for (i = 0; i < count; i++) { (*kvs_entry)[i] = (char*)malloc(sizeof(char) * max_count); if ((*kvs_entry)[i] == NULL) { - printf("Memory allocation failed\n"); - exit(1); + LOG_ERROR("Memory allocation failed\n"); + return KVS_STATUS_FAILURE; } } @@ -402,8 +461,8 @@ size_t get_by_template(char*** kvs_entry, SET_STR(get_val, REQUEST_POSTFIX_SIZE, CONCAT_TWO_COMMAND_TEMPLATE, request, template_str); SET_STR(run_str, RUN_REQUEST_SIZE, run_get_template, get_val); if ((fp = popen(run_str, READ_ONLY)) == NULL) { - printf("Can't get by template\n"); - exit(1); + LOG_ERROR("Can't get by template\n"); + return KVS_STATUS_FAILURE; } while 
((fgets((*kvs_entry)[i], max_count, fp) != NULL) && (i < count)) { while ((*kvs_entry)[i][strlen((*kvs_entry)[i]) - 1] == '\n' || @@ -412,18 +471,19 @@ size_t get_by_template(char*** kvs_entry, i++; } pclose(fp); - return 0; + return KVS_STATUS_SUCCESS; } -size_t request_k8s_get_keys_values_by_name(const char* kvs_name, - char*** kvs_keys, - char*** kvs_values) { +kvs_status_t request_k8s_get_keys_values_by_name(const char* kvs_name, + char*** kvs_keys, + char*** kvs_values, + int& values_count) { FILE* fp; char run_str[RUN_REQUEST_SIZE]; char grep_name_str[REQUEST_POSTFIX_SIZE]; char get_name_count[REQUEST_POSTFIX_SIZE]; char values_count_str[INT_STR_SIZE]; - size_t values_count; + values_count = 0; SET_STR(get_name_count, REQUEST_POSTFIX_SIZE, GREP_COUNT_TEMPLATE, kvs_name); @@ -431,30 +491,36 @@ size_t request_k8s_get_keys_values_by_name(const char* kvs_name, SET_STR(run_str, RUN_REQUEST_SIZE, run_get_template, get_name_count); if ((fp = popen(run_str, READ_ONLY)) == NULL) { - printf("Can't get keys-values by name: %s\n", kvs_name); - exit(1); + LOG_ERROR("Can't get keys-values by name: %s\n", kvs_name); + return KVS_STATUS_SUCCESS; } CHECK_FGETS(fgets(values_count_str, INT_STR_SIZE, fp), values_count_str); pclose(fp); - if ((values_count = safe_strtol(values_count_str, NULL, 10)) == 0) - return 0; + KVS_CHECK_STATUS(safe_strtol(values_count_str, values_count), "failed to convert count"); + if (values_count == 0) + return KVS_STATUS_SUCCESS; SET_STR(grep_name_str, REQUEST_POSTFIX_SIZE, GREP_TEMPLATE, kvs_name); if (kvs_values != NULL) { - get_by_template(kvs_values, grep_name_str, GET_VAL, values_count, MAX_KVS_VAL_LENGTH); + KVS_CHECK_STATUS( + get_by_template(kvs_values, grep_name_str, GET_VAL, values_count, MAX_KVS_VAL_LENGTH), + "failed to get val"); } if (kvs_keys != NULL) { - get_by_template(kvs_keys, grep_name_str, GET_KEY, values_count, MAX_KVS_KEY_LENGTH); + KVS_CHECK_STATUS( + get_by_template(kvs_keys, grep_name_str, GET_KEY, values_count, 
MAX_KVS_KEY_LENGTH), + "failed to get key"); } - return values_count; + return KVS_STATUS_SUCCESS; } -size_t request_k8s_get_count_names(const char* kvs_name) { +kvs_status_t request_k8s_get_count_names(const char* kvs_name, size_t& res) { FILE* fp; char run_str[RUN_REQUEST_SIZE]; char get_count_str[REQUEST_POSTFIX_SIZE]; char count_names[INT_STR_SIZE]; + res = 0; SET_STR(get_count_str, REQUEST_POSTFIX_SIZE, GREP_COUNT_TEMPLATE, kvs_name); @@ -462,16 +528,19 @@ size_t request_k8s_get_count_names(const char* kvs_name) { SET_STR(run_str, RUN_REQUEST_SIZE, run_get_template, get_count_str); if ((fp = popen(run_str, READ_ONLY)) == NULL) { - printf("Can't get names count: %s\n", kvs_name); - exit(1); + LOG_ERROR("Can't get names count: %s\n", kvs_name); + return KVS_STATUS_FAILURE; } CHECK_FGETS(fgets(count_names, INT_STR_SIZE, fp), count_names); pclose(fp); - return safe_strtol(count_names, NULL, 10); + KVS_CHECK_STATUS(safe_strtol(count_names, res), "failed to convert cont names"); + return KVS_STATUS_SUCCESS; } -size_t request_k8s_get_val_by_name_key(const char* kvs_name, const char* kvs_key, char* kvs_val) { +kvs_status_t request_k8s_get_val_by_name_key(const char* kvs_name, + const char* kvs_key, + char* kvs_val) { FILE* fp; char run_str[RUN_REQUEST_SIZE]; char grep_kvs_name_key[REQUEST_POSTFIX_SIZE]; @@ -487,16 +556,16 @@ size_t request_k8s_get_val_by_name_key(const char* kvs_name, const char* kvs_key SET_STR(run_str, RUN_REQUEST_SIZE, run_get_template, get_kvs_val); if ((fp = popen(run_str, READ_ONLY)) == NULL) { - printf("Can't get value by name-key: %s\n", kvs_name_key); - exit(1); + LOG_ERROR("Can't get value by name-key: %s\n", kvs_name_key); + return KVS_STATUS_FAILURE; } CHECK_FGETS(fgets(kvs_val, MAX_KVS_VAL_LENGTH, fp), kvs_val); pclose(fp); kvs_val[strlen(kvs_val) - 1] = NULL_CHAR; - return strlen(kvs_val); + return KVS_STATUS_SUCCESS; } -size_t request_k8s_remove_name_key(const char* kvs_name, const char* kvs_key) { +kvs_status_t 
request_k8s_remove_name_key(const char* kvs_name, const char* kvs_key) { FILE* fp; char run_str[RUN_REQUEST_SIZE]; char patch[REQUEST_POSTFIX_SIZE]; @@ -509,14 +578,14 @@ size_t request_k8s_remove_name_key(const char* kvs_name, const char* kvs_key) { SET_STR(run_str, RUN_REQUEST_SIZE, run_set_template, patch); if ((fp = popen(run_str, READ_ONLY)) == NULL) { - printf("Can't remove name-key: %s\n", kvs_name_key); - exit(1); + LOG_ERROR("Can't remove name-key: %s\n", kvs_name_key); + return KVS_STATUS_FAILURE; } pclose(fp); - return 0; + return KVS_STATUS_SUCCESS; } -size_t request_k8s_set_val(const char* kvs_name, const char* kvs_key, const char* kvs_val) { +kvs_status_t request_k8s_set_val(const char* kvs_name, const char* kvs_key, const char* kvs_val) { FILE* fp; char run_str[RUN_REQUEST_SIZE]; char patch[REQUEST_POSTFIX_SIZE]; @@ -529,21 +598,21 @@ size_t request_k8s_set_val(const char* kvs_name, const char* kvs_key, const char SET_STR(run_str, RUN_REQUEST_SIZE, run_set_template, patch); if ((fp = popen(run_str, READ_ONLY)) == NULL) { - printf("Can't set name-key-val: %s-%s\n", kvs_name_key, kvs_val); - exit(1); + LOG_ERROR("Can't set name-key-val: %s-%s\n", kvs_name_key, kvs_val); + return KVS_STATUS_FAILURE; } pclose(fp); - return 0; + return KVS_STATUS_SUCCESS; } -size_t request_k8s_get_replica_size(void) { +kvs_status_t request_k8s_get_replica_size(size_t& res) { FILE* fp; char run_str[RUN_REQUEST_SIZE]; char replica_size_str[MAX_KVS_VAL_LENGTH]; const char* replica_keys[] = { "spec", "replicas" }; switch (manager) { - case MT_NONE: return request_k8s_get_count_names(ccl_kvs_ip); + case MT_NONE: return request_k8s_get_count_names(ccl_kvs_ip, res); case MT_K8S: /*get full output*/ SET_STR(run_str, RUN_REQUEST_SIZE, run_get_template, ""); @@ -552,9 +621,12 @@ size_t request_k8s_get_replica_size(void) { printf("Can't get replica size\n"); exit(1); } - json_get_val(fp, replica_keys, 2, replica_size_str); + KVS_CHECK_STATUS(json_get_val(fp, replica_keys, 2, 
replica_size_str), + "failed to get replica size"); pclose(fp); - return safe_strtol(replica_size_str, NULL, 10); + KVS_CHECK_STATUS(safe_strtol(replica_size_str, res), "failed to convert replica size"); + return KVS_STATUS_SUCCESS; + default: LOG_ERROR("unknown k8s manager"); return KVS_STATUS_FAILURE; } - return 0; + return KVS_STATUS_SUCCESS; } diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.hpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.hpp index 86bdb7705..f9eaf6400 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.hpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/request_wrappers_k8s.hpp @@ -21,13 +21,15 @@ extern "C" { #endif #include -size_t request_k8s_kvs_init(void); +kvs_status_t request_k8s_kvs_init(void); -size_t request_k8s_kvs_get_master(const char* local_host_ip, char* main_host_ip, char* port_str); +kvs_status_t request_k8s_kvs_get_master(const char* local_host_ip, + char* main_host_ip, + char* port_str); -size_t request_k8s_kvs_finalize(size_t is_master); +kvs_status_t request_k8s_kvs_finalize(size_t is_master); -size_t request_k8s_get_replica_size(void); +kvs_status_t request_k8s_get_replica_size(size_t& res); #ifdef __cplusplus } diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/resizable_pmi.cpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/resizable_pmi.cpp index 085bf0e8d..e5266535d 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/resizable_pmi.cpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/resizable_pmi.cpp @@ -27,21 +27,31 @@ char my_hostname[MAX_KVS_VAL_LENGTH]; // TODO: rework it for multi kvs static pmi_resizable* pmi_object; -void Call_Hard_finilize(int sig) { - pmi_object->Hard_finilize(sig); +void call_hard_finalize(int sig) { + if (pmi_object->hard_finalize(sig) != KVS_STATUS_SUCCESS) { + LOG_ERROR("failed to hard finalize"); + } } kvs_resize_action_t pmi_resizable::default_checker(int comm_size) { char* 
comm_size_to_start_env; - int comm_size_to_start; + size_t comm_size_to_start; comm_size_to_start_env = getenv(CCL_WORLD_SIZE_ENV); - if (comm_size_to_start_env != NULL) - comm_size_to_start = safe_strtol(comm_size_to_start_env, NULL, 10); - else - comm_size_to_start = h->get_replica_size(); - if (comm_size >= comm_size_to_start) + if (comm_size_to_start_env != NULL) { + if (safe_strtol(comm_size_to_start_env, comm_size_to_start) != KVS_STATUS_SUCCESS) { + LOG_ERROR("failed to convert comm_size"); + return KVS_RA_FINALIZE; + } + } + else { + if (h->get_replica_size(comm_size_to_start) != KVS_STATUS_SUCCESS) { + LOG_ERROR("failed to get comm_size"); + return KVS_RA_FINALIZE; + } + } + if (comm_size >= static_cast(comm_size_to_start)) return KVS_RA_RUN; return KVS_RA_WAIT; @@ -54,27 +64,29 @@ kvs_resize_action_t pmi_resizable::call_resize_fn(int comm_size) { return default_checker(comm_size); } -int pmi_resizable::PMIR_Update(void) { +kvs_status_t pmi_resizable::PMIR_Update(void) { char up_idx_str[MAX_KVS_VAL_LENGTH]; int prev_new_ranks_count = 0; int prev_killed_ranks_count = 0; int prev_idx = -1; kvs_resize_action_t answer; - rank_list_t* dead_up_idx = NULL; - shift_list_t* list = NULL; + std::list dead_up_idx{}; + std::list list{}; new_ranks_count = 0; killed_ranks_count = 0; if (finalized == 1) { - return 1; + LOG_ERROR("is finalized"); + return KVS_STATUS_FAILURE; } if (applied == 1) { size_t is_wait = 1; size_t is_first_collect = 0; - h->get_value_by_name_key(KVS_UP, KVS_IDX, up_idx_str); + KVS_CHECK_STATUS(h->get_value_by_name_key(KVS_UP, KVS_IDX, up_idx_str), + "failed to get KVS IDx"); - up_idx = safe_strtol(up_idx_str, NULL, 10); + KVS_CHECK_STATUS(safe_strtol(up_idx_str, up_idx), "failed to convert KVS IDx"); if (up_idx == 0) is_first_collect = 1; @@ -84,9 +96,10 @@ int pmi_resizable::PMIR_Update(void) { do { /*Waiting new pods*/ usleep(10000); - h->get_value_by_name_key(KVS_UP, KVS_IDX, up_idx_str); + KVS_CHECK_STATUS(h->get_value_by_name_key(KVS_UP, 
KVS_IDX, up_idx_str), + "failed to get KVS IDx"); - up_idx = safe_strtol(up_idx_str, NULL, 10); + KVS_CHECK_STATUS(safe_strtol(up_idx_str, up_idx), "failed to convert KVS IDx"); if (prev_idx == (int)up_idx) { count_clean_checks = 0; @@ -101,7 +114,7 @@ int pmi_resizable::PMIR_Update(void) { // while (int_list_is_contained(killed_ranks, root_rank) == 1) { int old_root = root_rank; - h->get_new_root(&root_rank); + KVS_CHECK_STATUS(h->get_new_root(&root_rank), "failed to new root rank"); if (my_rank == root_rank && old_root != root_rank) is_new_root = 1; @@ -114,27 +127,29 @@ int pmi_resizable::PMIR_Update(void) { prev_new_ranks_count = new_ranks_count; prev_killed_ranks_count = killed_ranks_count; - h->get_update_ranks(); + KVS_CHECK_STATUS(h->get_update_ranks(), "failed to update ranks"); if (killed_ranks_count != prev_killed_ranks_count) - rank_list_add(&dead_up_idx, up_idx); + dead_up_idx.push_back(up_idx); } - PMIR_Barrier(); + KVS_CHECK_STATUS(PMIR_Barrier(), "barrier failed"); if (my_rank == root_rank && is_new_root == 0) { up_idx++; if (up_idx > 0 && up_idx > MAX_UP_IDX) up_idx = 1; SET_STR(up_idx_str, INT_STR_SIZE, SIZE_T_TEMPLATE, up_idx); - h->set_value(KVS_UP, KVS_IDX, up_idx_str); - h->up_kvs_new_and_dead(); + KVS_CHECK_STATUS(h->set_value(KVS_UP, KVS_IDX, up_idx_str), + "failed to set KVS IDx"); + KVS_CHECK_STATUS(h->up_kvs_new_and_dead(), "failed to update KVS"); } - PMIR_Barrier(); + KVS_CHECK_STATUS(PMIR_Barrier(), "barrier failed"); if (finalized == 1) { - rank_list_clean(&killed_ranks); - rank_list_clean(&new_ranks); - rank_list_clean(&dead_up_idx); - return 1; + killed_ranks.clear(); + new_ranks.clear(); + dead_up_idx.clear(); + LOG_ERROR("is finalized") + return KVS_STATUS_FAILURE; } is_new_root = 0; @@ -151,7 +166,9 @@ int pmi_resizable::PMIR_Update(void) { if (!is_first_collect || ask_only_framework == 1) answer = call_resize_fn(count_pods - killed_ranks_count + new_ranks_count); else { - if ((int)(h->get_replica_size()) != + size_t 
replica_size; + KVS_CHECK_STATUS(h->get_replica_size(replica_size), "failed to get replica size"); + if (static_cast(replica_size) != count_pods - killed_ranks_count + new_ranks_count) answer = KVS_RA_WAIT; else @@ -167,60 +184,60 @@ int pmi_resizable::PMIR_Update(void) { break; } case KVS_RA_FINALIZE: { - PMIR_Finalize(); - return 1; + KVS_CHECK_STATUS(PMIR_Finalize(), "failed to finalize"); } default: { - printf("Unknown resize action: %d\n", answer); - PMIR_Finalize(); - return 1; + LOG_ERROR("Unknown resize action: %d\n", answer); + KVS_CHECK_STATUS(PMIR_Finalize(), "failed to finalize"); + return KVS_STATUS_FAILURE; } } listener.set_applied_count(count_applied_changes); } while (is_wait == 1); } else { - listener.send_notification(0, h); - h->wait_accept(); + KVS_CHECK_STATUS(listener.send_notification(0, h), "failed to send notification"); + KVS_CHECK_STATUS(h->wait_accept(), "failed to wait accept"); } - h->get_shift(&list); + h->get_shift(list); count_pods = count_pods - killed_ranks_count + new_ranks_count; - h->update(&list, &dead_up_idx, root_rank); + KVS_CHECK_STATUS(h->update(list, dead_up_idx, root_rank), "failed to update root"); root_rank = 0; - PMIR_Barrier(); - h->up_pods_count(); + KVS_CHECK_STATUS(PMIR_Barrier(), "barrier failed"); + KVS_CHECK_STATUS(h->up_pods_count(), "failed to update pods count"); - rank_list_clean(&killed_ranks); - rank_list_clean(&new_ranks); - rank_list_clean(&dead_up_idx); - shift_list_clean(&list); - return 0; + killed_ranks.clear(); + new_ranks.clear(); + dead_up_idx.clear(); + list.clear(); + return KVS_STATUS_SUCCESS; } -void pmi_resizable::Hard_finilize(int sig) { +kvs_status_t pmi_resizable::hard_finalize(int sig) { char rank_str[INT_STR_SIZE]; SET_STR(rank_str, INT_STR_SIZE, RANK_TEMPLATE, my_rank); - h->set_value(KVS_DEAD_POD, my_hostname, rank_str); + KVS_CHECK_STATUS(h->set_value(KVS_DEAD_POD, my_hostname, rank_str), "failed to set dead rank"); - listener.send_notification(sig, h); + 
KVS_CHECK_STATUS(listener.send_notification(sig, h), "failed to send notification"); extreme_finalize = 1; - PMIR_Finalize(); + KVS_CHECK_STATUS(PMIR_Finalize(), "failed to finalize"); if (old_act.sa_handler != NULL) old_act.sa_handler(sig); + + return KVS_STATUS_SUCCESS; } -int pmi_resizable::PMIR_Main_Addr_Reserve(char* main_addr) { - h->main_server_address_reserve(main_addr); - return 0; +kvs_status_t pmi_resizable::PMIR_Main_Addr_Reserve(char* main_addr) { + return h->main_server_address_reserve(main_addr); } -int pmi_resizable::PMIR_Init(const char* main_addr) { +kvs_status_t pmi_resizable::PMIR_Init(const char* main_addr) { struct sigaction act; FILE* fp; finalized = 0; @@ -240,34 +257,34 @@ int pmi_resizable::PMIR_Init(const char* main_addr) { "-%d", getpid()); - if (h->init(main_addr)) { - return 1; - } + KVS_CHECK_STATUS(h->init(main_addr), "failed to init"); - h->reg_rank(); + KVS_CHECK_STATUS(h->reg_rank(), "failed to rank register"); - h->up_pods_count(); + KVS_CHECK_STATUS(h->up_pods_count(), "failed to update pods count"); // TODO: rework it for multi kvs pmi_object = this; memset(&act, 0, sizeof(act)); - act.sa_handler = &Call_Hard_finilize; + act.sa_handler = &call_hard_finalize; act.sa_flags = 0; sigaction(SIGTERM, &act, &old_act); - return 0; + return KVS_STATUS_SUCCESS; } -int pmi_resizable::PMIR_Finalize(void) { +kvs_status_t pmi_resizable::PMIR_Finalize(void) { char kvs_name[MAX_KVS_NAME_LENGTH]; char kvs_key[MAX_KVS_KEY_LENGTH]; char kvs_val[MAX_KVS_VAL_LENGTH]; char rank_str[INT_STR_SIZE]; - if (finalized) - return 0; + if (finalized) { + return KVS_STATUS_SUCCESS; + } - if (my_rank == 0) - PMIR_Barrier(); + if (my_rank == 0) { + KVS_CHECK_STATUS(PMIR_Barrier(), "barrier failed"); + } finalized = 1; @@ -275,101 +292,106 @@ int pmi_resizable::PMIR_Finalize(void) { SET_STR(rank_str, INT_STR_SIZE, RANK_TEMPLATE, my_rank); - h->remove_name_key(KVS_POD_NUM, rank_str); + KVS_CHECK_STATUS(h->remove_name_key(KVS_POD_NUM, rank_str), "failed to remove 
rank"); while (cut_head(kvs_name, kvs_key, kvs_val, ST_CLIENT)) { - h->remove_name_key(kvs_name, kvs_key); + KVS_CHECK_STATUS(h->remove_name_key(kvs_name, kvs_key), "failed to remove info"); } if (my_rank == 0 && extreme_finalize != 1) { - h->remove_name_key(KVS_UP, KVS_IDX); + KVS_CHECK_STATUS(h->remove_name_key(KVS_UP, KVS_IDX), "failed to remove IDx"); } - h->remove_name_key(KVS_BARRIER, my_hostname); + KVS_CHECK_STATUS(h->remove_name_key(KVS_BARRIER, my_hostname), "failed to remove barrier info"); - h->finalize(); + KVS_CHECK_STATUS(h->finalize(), "failed to finalize"); - return 0; + return KVS_STATUS_SUCCESS; } -int pmi_resizable::PMIR_Barrier(void) { +kvs_status_t pmi_resizable::PMIR_Barrier(void) { size_t min_barrier_num; char barrier_num_str[INT_STR_SIZE]; if (finalized) - return 0; + return KVS_STATUS_SUCCESS; SET_STR(barrier_num_str, INT_STR_SIZE, SIZE_T_TEMPLATE, barrier_num); - h->set_value(KVS_BARRIER, my_hostname, barrier_num_str); + KVS_CHECK_STATUS(h->set_value(KVS_BARRIER, my_hostname, barrier_num_str), + "failed to set barrier info"); - min_barrier_num = h->get_barrier_idx(); + KVS_CHECK_STATUS(h->get_barrier_idx(min_barrier_num), "failed to get barrier IDx"); while (min_barrier_num != barrier_num && finalized != 1) { - min_barrier_num = h->get_barrier_idx(); + KVS_CHECK_STATUS(h->get_barrier_idx(min_barrier_num), "failed to get barrier IDx"); } barrier_num++; if (barrier_num > BARRIER_NUM_MAX) barrier_num = 0; - return 0; + return KVS_STATUS_SUCCESS; } -int pmi_resizable::PMIR_Get_size(int* size) { +kvs_status_t pmi_resizable::PMIR_Get_size(int* size) { *size = count_pods; - return 0; + return KVS_STATUS_SUCCESS; } -int pmi_resizable::PMIR_Get_rank(int* rank) { +kvs_status_t pmi_resizable::PMIR_Get_rank(int* rank) { *rank = my_rank; - return 0; + return KVS_STATUS_SUCCESS; } -int pmi_resizable::PMIR_KVS_Get_my_name(char* kvs_name, size_t length) { +kvs_status_t pmi_resizable::PMIR_KVS_Get_my_name(char* kvs_name, size_t length) { 
kvs_str_copy(kvs_name, KVS_NAME, length); - return 0; + return KVS_STATUS_SUCCESS; } -int pmi_resizable::PMIR_KVS_Get_name_length_max(size_t* length) { +kvs_status_t pmi_resizable::PMIR_KVS_Get_name_length_max(size_t* length) { *length = MAX_KVS_NAME_LENGTH; - return 0; + return KVS_STATUS_SUCCESS; } -int pmi_resizable::PMIR_KVS_Get_key_length_max(size_t* length) { +kvs_status_t pmi_resizable::PMIR_KVS_Get_key_length_max(size_t* length) { *length = MAX_KVS_KEY_LENGTH; - return 0; + return KVS_STATUS_SUCCESS; } -int pmi_resizable::PMIR_KVS_Get_value_length_max(size_t* length) { +kvs_status_t pmi_resizable::PMIR_KVS_Get_value_length_max(size_t* length) { *length = MAX_KVS_VAL_LENGTH; - return 0; + return KVS_STATUS_SUCCESS; } -int pmi_resizable::PMIR_KVS_Commit(const char* kvs_name) { +kvs_status_t pmi_resizable::PMIR_KVS_Commit(const char* kvs_name) { (void)kvs_name; - return 0; + return KVS_STATUS_SUCCESS; } -int pmi_resizable::PMIR_KVS_Put(const char* kvs_name, const char* key, const char* value) { +kvs_status_t pmi_resizable::PMIR_KVS_Put(const char* kvs_name, const char* key, const char* value) { put_key(kvs_name, key, value, ST_CLIENT); - h->set_value(kvs_name, key, value); - return 0; + KVS_CHECK_STATUS(h->set_value(kvs_name, key, value), "failed to set value"); + return KVS_STATUS_SUCCESS; } -int pmi_resizable::PMIR_KVS_Get(const char* kvs_name, const char* key, char* value, size_t length) { +kvs_status_t pmi_resizable::PMIR_KVS_Get(const char* kvs_name, + const char* key, + char* value, + size_t length) { (void)length; - while (h->get_value_by_name_key(kvs_name, key, value) == 0) { - } + do { + KVS_CHECK_STATUS(h->get_value_by_name_key(kvs_name, key, value), "failed to get value"); + } while (strlen(value) == 0); - return 0; + return KVS_STATUS_SUCCESS; } -int pmi_resizable::PMIR_set_resize_function(pmir_resize_fn_t resize_fn) { +kvs_status_t pmi_resizable::PMIR_set_resize_function(pmir_resize_fn_t resize_fn) { resize_function = resize_fn; - return 0; + 
return KVS_STATUS_SUCCESS; } -int pmi_resizable::PMIR_Wait_notification(void) { +kvs_status_t pmi_resizable::PMIR_Wait_notification(void) { return listener.run_listener(h); } @@ -384,11 +406,15 @@ int pmi_resizable::get_size() { size_t pmi_resizable::get_local_thread_idx() { return 0; } -size_t pmi_resizable::get_local_kvs_id() { - return 0; +atl_status_t pmi_resizable::get_local_kvs_id(size_t& res) { + res = 0; + return ATL_STATUS_SUCCESS; +} +atl_status_t pmi_resizable::set_local_kvs_id(size_t local_kvs_id) { + return ATL_STATUS_SUCCESS; } -void pmi_resizable::set_local_kvs_id(size_t local_kvs_id) {} pmi_resizable::~pmi_resizable() { - if (!is_finalized) - pmrt_finalize(); + if (!is_finalized) { + CCL_THROW_IF_NOT(pmrt_finalize(), "pmi finalize failed"); + } } diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/resizable_pmi.h b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/resizable_pmi.h deleted file mode 100644 index b9ef14a62..000000000 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/resizable_pmi.h +++ /dev/null @@ -1,84 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-*/ -#ifndef PMIR_H_INCLUDED -#define PMIR_H_INCLUDED - -#include - -#ifdef __cplusplus -extern "C" { -#endif -#define PMIR_API __attribute__((visibility("default"))) - -#define PMIR_SUCCESS 0 -#define PMIR_FAIL -1 -#define PMIR_ERR_INIT 1 -#define PMIR_ERR_NOMEM 2 -#define PMIR_ERR_INVALID_ARG 3 -#define PMIR_ERR_INVALID_KEY 4 -#define PMIR_ERR_INVALID_KEY_LENGTH 5 -#define PMIR_ERR_INVALID_VAL 6 -#define PMIR_ERR_INVALID_VAL_LENGTH 7 -#define PMIR_ERR_INVALID_LENGTH 8 -#define PMIR_ERR_INVALID_NUM_ARGS 9 -#define PMIR_ERR_INVALID_ARGS 10 -#define PMIR_ERR_INVALID_NUM_PARSED 11 -#define PMIR_ERR_INVALID_KEYVALP 12 -#define PMIR_ERR_INVALID_SIZE 13 - -typedef enum { - KVS_RA_WAIT = 0, - KVS_RA_RUN = 1, - KVS_RA_FINALIZE = 2, -} kvs_resize_action_t; -typedef kvs_resize_action_t (*pmir_resize_fn_t)(int comm_size); - -int PMIR_API PMIR_Main_Addr_Reserve(char* main_addr); - -int PMIR_API PMIR_Init(const char* main_addr); - -int PMIR_API PMIR_Finalize(void); - -int PMIR_API PMIR_Get_size(int* size); - -int PMIR_API PMIR_Get_rank(int* rank); - -int PMIR_API PMIR_KVS_Get_my_name(char* kvs_name, size_t length); - -int PMIR_API PMIR_KVS_Get_name_length_max(size_t* length); - -int PMIR_API PMIR_Barrier(void); - -int PMIR_API PMIR_Update(void); - -int PMIR_API PMIR_KVS_Get_key_length_max(size_t* length); - -int PMIR_API PMIR_KVS_Get_value_length_max(size_t* length); - -int PMIR_API PMIR_KVS_Put(const char* kvs_name, const char* key, const char* value); - -int PMIR_API PMIR_KVS_Commit(const char* kvs_name); - -int PMIR_API PMIR_KVS_Get(const char* kvs_name, const char* key, char* value, size_t length); - -int PMIR_API PMIR_set_resize_function(pmir_resize_fn_t resize_fn); - -int PMIR_API PMIR_Wait_notification(void); - -#ifdef __cplusplus -} -#endif -#endif diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.cpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.cpp deleted file mode 100644 index 30fe48722..000000000 --- 
a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.cpp +++ /dev/null @@ -1,53 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#include -#include - -#include "shift_list.hpp" - -void shift_list_clean(shift_list_t** list) { - shift_list_t* cur_list = (*list); - shift_list_t* node_to_remove; - while (cur_list != NULL) { - node_to_remove = cur_list; - cur_list = cur_list->next; - free(node_to_remove); - } - (*list) = NULL; -} - -void shift_list_add(shift_list_t** list, int old_rank, int new_rank, change_type_t type) { - shift_list_t* cur_list; - if ((*list) == NULL) { - (*list) = (shift_list_t*)malloc(sizeof(shift_list_t)); - if ((*list) == NULL) { - printf("Memory allocation failed\n"); - return; - } - cur_list = (*list); - } - else { - cur_list = (*list); - while (cur_list->next != NULL) - cur_list = cur_list->next; - cur_list->next = (shift_list_t*)malloc(sizeof(shift_list_t)); - cur_list = cur_list->next; - } - cur_list->shift.old_rank = old_rank; - cur_list->shift.new_rank = new_rank; - cur_list->shift.type = type; - cur_list->next = NULL; -} diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.hpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.hpp index 59f3d6a5e..876cc1e06 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.hpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable/shift_list.hpp @@ -13,12 +13,7 @@ See the License for the specific language governing 
permissions and limitations under the License. */ -#ifndef SHIFT_LIST_H_INCLUDED -#define SHIFT_LIST_H_INCLUDED - -#ifdef __cplusplus -extern "C" { -#endif +#pragma once typedef enum change_type { CH_T_SHIFT = 0, CH_T_DEAD = 1, @@ -31,17 +26,3 @@ typedef struct shift_rank { int new_rank; change_type_t type; } shift_rank_t; - -typedef struct shift_list { - shift_rank_t shift; - struct shift_list* next; -} shift_list_t; - -void shift_list_clean(shift_list_t** list); - -void shift_list_add(shift_list_t** list, int old_rank, int new_rank, change_type_t type); - -#ifdef __cplusplus -} -#endif -#endif diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_rt.c b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_rt.c deleted file mode 100644 index 4cfc14dc3..000000000 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_rt.c +++ /dev/null @@ -1,279 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-*/ -#include "pm_rt_codec.h" - -#include -#include -#include - -#include "pmi_resizable/resizable_pmi.h" - -#include "pm_rt.h" - -#define RESIZABLE_PMI_RT_KEY_FORMAT "%s-%d" - -typedef struct resizable_pm_rt_context { - pm_rt_desc_t pmrt_desc; - struct { - size_t initialized; - size_t ref_cnt; - size_t max_keylen; - size_t max_vallen; - char *key_storage; - char *val_storage; - char *kvsname; - } resizablert_main; -} resizable_pm_context_t; - -/* Ensures that this is allocated/initialized only once per process */ -static resizable_pm_context_t resizable_ctx_singleton; - -static void resizable_pmirt_finalize(pm_rt_desc_t *pmrt_desc) { - resizable_pm_context_t *ctx = container_of(pmrt_desc, resizable_pm_context_t, pmrt_desc); - if (!ctx->resizablert_main.initialized) - return; - - if (--ctx->resizablert_main.ref_cnt) - return; - - free(ctx->resizablert_main.kvsname); - free(ctx->resizablert_main.key_storage); - free(ctx->resizablert_main.val_storage); - - PMIR_Finalize(); - - memset(ctx, 0, sizeof(*ctx)); -} - -static void resizable_pmirt_barrier(pm_rt_desc_t *pmrt_desc) { - resizable_pm_context_t *ctx = container_of(pmrt_desc, resizable_pm_context_t, pmrt_desc); - - if (!ctx->resizablert_main.initialized) - return; - - PMIR_Barrier(); -} - -static atl_status_t resizable_pmirt_kvs_put(pm_rt_desc_t *pmrt_desc, - char *kvs_key, - int proc_idx, - const void *kvs_val, - size_t kvs_val_len) { - int ret; - resizable_pm_context_t *ctx = container_of(pmrt_desc, resizable_pm_context_t, pmrt_desc); - - if (!ctx->resizablert_main.initialized) - return ATL_STATUS_FAILURE; - - if (kvs_val_len > ctx->resizablert_main.max_vallen) - return ATL_STATUS_FAILURE; - - ret = snprintf(ctx->resizablert_main.key_storage, - ctx->resizablert_main.max_keylen - 1, - RESIZABLE_PMI_RT_KEY_FORMAT, - kvs_key, - proc_idx); - if (ret < 0) - return ATL_STATUS_FAILURE; - - ret = encode( - kvs_val, kvs_val_len, ctx->resizablert_main.val_storage, ctx->resizablert_main.max_vallen); - if (ret) - return 
ATL_STATUS_FAILURE; - - ret = PMIR_KVS_Put(ctx->resizablert_main.kvsname, - ctx->resizablert_main.key_storage, - ctx->resizablert_main.val_storage); - if (ret != PMIR_SUCCESS) - return ATL_STATUS_FAILURE; - - ret = PMIR_KVS_Commit(ctx->resizablert_main.kvsname); - if (ret != PMIR_SUCCESS) - return ATL_STATUS_FAILURE; - - return ATL_STATUS_SUCCESS; -} - -static atl_status_t resizable_pmirt_kvs_get(pm_rt_desc_t *pmrt_desc, - char *kvs_key, - int proc_idx, - void *kvs_val, - size_t kvs_val_len) { - int ret; - resizable_pm_context_t *ctx = container_of(pmrt_desc, resizable_pm_context_t, pmrt_desc); - - if (!ctx->resizablert_main.initialized) - return ATL_STATUS_FAILURE; - - ret = snprintf(ctx->resizablert_main.key_storage, - ctx->resizablert_main.max_keylen - 1, - RESIZABLE_PMI_RT_KEY_FORMAT, - kvs_key, - proc_idx); - if (ret < 0) - return ATL_STATUS_FAILURE; - - ret = PMIR_KVS_Get(ctx->resizablert_main.kvsname, - ctx->resizablert_main.key_storage, - ctx->resizablert_main.val_storage, - ctx->resizablert_main.max_vallen); - if (ret != PMIR_SUCCESS) - return ATL_STATUS_FAILURE; - - ret = decode(ctx->resizablert_main.val_storage, kvs_val, kvs_val_len); - if (ret) - return ATL_STATUS_FAILURE; - - return ATL_STATUS_SUCCESS; -} - -static atl_status_t resizable_pmirt_update(int *proc_idx, int *proc_count) { - int ret; - ret = PMIR_Update(); - if (ret != PMIR_SUCCESS) - goto err_resizable; - - ret = PMIR_Get_size(proc_count); - if (ret != PMIR_SUCCESS) - goto err_resizable; - - ret = PMIR_Get_rank(proc_idx); - if (ret != PMIR_SUCCESS) - goto err_resizable; - - return ATL_STATUS_SUCCESS; - -err_resizable: - PMIR_Finalize(); - return ATL_STATUS_FAILURE; -} - -atl_status_t resizable_pmirt_wait_notification() { - int ret; - - ret = PMIR_Wait_notification(); - - if (ret != PMIR_SUCCESS) - return ATL_STATUS_FAILURE; - - return ATL_STATUS_SUCCESS; -} - -pm_rt_ops_t resizable_ops = { - .finalize = resizable_pmirt_finalize, - .barrier = resizable_pmirt_barrier, - .update = 
resizable_pmirt_update, - .wait_notification = resizable_pmirt_wait_notification, -}; - -pm_rt_kvs_ops_t resizable_kvs_ops = { - .put = resizable_pmirt_kvs_put, - .get = resizable_pmirt_kvs_get, -}; - -atl_status_t resizable_pmirt_init(int *proc_idx, - int *proc_count, - pm_rt_desc_t **pmrt_desc, - const char *main_addr) { - int ret; - size_t max_kvsnamelen; - - if (resizable_ctx_singleton.resizablert_main.initialized) { - PMIR_Get_size(proc_idx); - PMIR_Get_rank(proc_count); - *pmrt_desc = &resizable_ctx_singleton.pmrt_desc; - resizable_ctx_singleton.resizablert_main.ref_cnt++; - return ATL_STATUS_SUCCESS; - } - - ret = PMIR_Init(main_addr); - if (ret != PMIR_SUCCESS) - return ATL_STATUS_FAILURE; - - ret = PMIR_Update(); - if (ret != PMIR_SUCCESS) - return ATL_STATUS_FAILURE; - - ret = PMIR_Get_size(proc_count); - if (ret != PMIR_SUCCESS) - goto err_resizable; - ret = PMIR_Get_rank(proc_idx); - if (ret != PMIR_SUCCESS) - goto err_resizable; - - ret = PMIR_KVS_Get_name_length_max(&max_kvsnamelen); - if (ret != PMIR_SUCCESS) - goto err_resizable; - - resizable_ctx_singleton.resizablert_main.kvsname = calloc(1, max_kvsnamelen); - if (!resizable_ctx_singleton.resizablert_main.kvsname) - goto err_resizable; - - ret = PMIR_KVS_Get_my_name(resizable_ctx_singleton.resizablert_main.kvsname, max_kvsnamelen); - if (ret != PMIR_SUCCESS) - goto err_alloc_key; - - ret = PMIR_KVS_Get_key_length_max(&resizable_ctx_singleton.resizablert_main.max_keylen); - if (ret != PMIR_SUCCESS) - goto err_alloc_key; - - resizable_ctx_singleton.resizablert_main.key_storage = - (char *)calloc(1, resizable_ctx_singleton.resizablert_main.max_keylen); - if (!resizable_ctx_singleton.resizablert_main.key_storage) - goto err_alloc_key; - - ret = PMIR_KVS_Get_value_length_max(&resizable_ctx_singleton.resizablert_main.max_vallen); - if (ret != PMIR_SUCCESS) - goto err_alloc_val; - - resizable_ctx_singleton.resizablert_main.val_storage = - (char *)calloc(1, 
resizable_ctx_singleton.resizablert_main.max_vallen); - if (!resizable_ctx_singleton.resizablert_main.val_storage) - goto err_alloc_val; - - resizable_ctx_singleton.resizablert_main.initialized = 1; - resizable_ctx_singleton.resizablert_main.ref_cnt = 1; - resizable_ctx_singleton.pmrt_desc.ops = &resizable_ops; - resizable_ctx_singleton.pmrt_desc.kvs_ops = &resizable_kvs_ops; - *pmrt_desc = &resizable_ctx_singleton.pmrt_desc; - - return ATL_STATUS_SUCCESS; -err_alloc_val: - free(resizable_ctx_singleton.resizablert_main.key_storage); -err_alloc_key: - free(resizable_ctx_singleton.resizablert_main.kvsname); -err_resizable: - PMIR_Finalize(); - return ATL_STATUS_FAILURE; -} - -atl_status_t resizable_pmirt_main_addr_reserve(char *main_addr) { - int ret = PMIR_Main_Addr_Reserve(main_addr); - - if (ret) - return ATL_STATUS_FAILURE; - - return ATL_STATUS_SUCCESS; -} - -atl_status_t resizable_pmirt_set_resize_function(atl_resize_fn_t resize_fn) { - int ret = PMIR_set_resize_function((pmir_resize_fn_t)resize_fn); - - if (ret) - return ATL_STATUS_FAILURE; - - return ATL_STATUS_SUCCESS; -} diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple.cpp b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple.cpp index 19c562854..0d8815bfd 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple.cpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple.cpp @@ -35,17 +35,17 @@ pmi_resizable_simple::pmi_resizable_simple(int size, const char* main_addr) : total_rank_count(size), ranks(ranks), - k(k) { + k(k), + main_addr(main_addr) { max_keylen = MAX_KVS_KEY_LENGTH; max_vallen = MAX_KVS_VAL_LENGTH; - pmrt_init(main_addr); } int pmi_resizable_simple::is_pm_resize_enabled() { return 0; } -atl_status_t pmi_resizable_simple::pmrt_init(const char* main_addr) { +atl_status_t pmi_resizable_simple::pmrt_init() { (void)main_addr; char* kvs_get_timeout_str = getenv("CCL_KVS_GET_TIMEOUT"); @@ -55,31 +55,36 @@ atl_status_t pmi_resizable_simple::pmrt_init(const char* 
main_addr) { local_id = 0; val_storage = (char*)calloc(1, max_vallen); - if (!val_storage) + if (!val_storage) { + LOG_ERROR("mem alloc failed"); return ATL_STATUS_FAILURE; + } /*TODO: add sort, ranks should increase continiusly*/ if (ranks[0] == 0) { - size_t tmp_local_id = get_local_kvs_id(); + size_t tmp_local_id; + ATL_CHECK_STATUS(get_local_kvs_id(tmp_local_id), "failed to get local id"); tmp_local_id++; - set_local_kvs_id(tmp_local_id); + ATL_CHECK_STATUS(set_local_kvs_id(tmp_local_id), "failed to set local id"); } - make_requested_info(); + + ATL_CHECK_STATUS(make_requested_info(), "failed to make requested info"); /* extension */ // make_map_requested2global(); /**/ return ATL_STATUS_SUCCESS; } -void pmi_resizable_simple::make_requested_info() { - register_first_rank_idx_and_rank_count(); - assign_thread_idx_and_fill_ranks_per_thread_map(); +atl_status_t pmi_resizable_simple::make_requested_info() { + ATL_CHECK_STATUS(register_first_rank_idx_and_rank_count(), "failed to register ranks"); + ATL_CHECK_STATUS(assign_thread_idx_and_fill_ranks_per_thread_map(), "failed to fill map"); - local_id = get_local_kvs_id(); - register_my_proc_name(); - get_my_proc_idx_and_proc_count(); + ATL_CHECK_STATUS(get_local_kvs_id(local_id), "failed to get local id"); + ATL_CHECK_STATUS(register_my_proc_name(), "failed to register proc name"); + ATL_CHECK_STATUS(get_my_proc_idx_and_proc_count(), "failed to get proc idx"); calculate_local_thread_idx(); - remove_initial_data(); - pmrt_barrier_full(); + ATL_CHECK_STATUS(remove_initial_data(), "failed to remove initial data"); + ATL_CHECK_STATUS(pmrt_barrier_full(), "full barrier failed"); + return ATL_STATUS_SUCCESS; } atl_status_t pmi_resizable_simple::pmrt_main_addr_reserve(char* main_addr) { @@ -98,14 +103,13 @@ atl_status_t pmi_resizable_simple::pmrt_wait_notification() { return ATL_STATUS_UNSUPPORTED; } -void pmi_resizable_simple::pmrt_finalize() { +atl_status_t pmi_resizable_simple::pmrt_finalize() { is_finalized = true; 
free(val_storage); if (getenv("CCL_PMI_FORCE_FINALIZE")) { - printf("skip pmi_resizable_simple::pmrt_finalize\n"); - fflush(stdout); - return; + LOG_WARN("skip pmi_resizable_simple::pmrt_finalize\n"); + return ATL_STATUS_SUCCESS; } char kvs_name[MAX_KVS_NAME_LENGTH]; @@ -113,63 +117,74 @@ void pmi_resizable_simple::pmrt_finalize() { char kvs_val[MAX_KVS_VAL_LENGTH]; while (cut_head(kvs_name, kvs_key, kvs_val, ST_CLIENT)) { - k->kvs_remove_name_key(kvs_name, kvs_key); + KVS_2_ATL_CHECK_STATUS(k->kvs_remove_name_key(kvs_name, kvs_key), "failed to remove info"); } + return ATL_STATUS_SUCCESS; } -void pmi_resizable_simple::pmrt_barrier() { +atl_status_t pmi_resizable_simple::pmrt_barrier() { size_t min_barrier_num; char barrier_num_str[INT_STR_SIZE]; - SET_STR(barrier_num_str, INT_STR_SIZE, SIZE_T_TEMPLATE, barrier_num); + ATL_SET_STR(barrier_num_str, INT_STR_SIZE, SIZE_T_TEMPLATE, barrier_num); - kvs_set_value(KVS_BARRIER, std::to_string(assigned_proc_idx).c_str(), barrier_num_str); + ATL_CHECK_STATUS( + kvs_set_value(KVS_BARRIER, std::to_string(assigned_proc_idx).c_str(), barrier_num_str), + "failed to set barrier num"); - min_barrier_num = get_barrier_idx(); - while (min_barrier_num != barrier_num) { - min_barrier_num = get_barrier_idx(); - } + do { + ATL_CHECK_STATUS(get_barrier_idx(min_barrier_num), "failed to get barrier num"); + } while (min_barrier_num != barrier_num); barrier_num++; if (barrier_num > BARRIER_NUM_MAX) barrier_num = 0; + return ATL_STATUS_SUCCESS; } -void pmi_resizable_simple::pmrt_barrier_full() { +atl_status_t pmi_resizable_simple::pmrt_barrier_full() { size_t min_barrier_num; char barrier_num_str[INT_STR_SIZE]; - SET_STR(barrier_num_str, INT_STR_SIZE, SIZE_T_TEMPLATE, barrier_num_full); + ATL_SET_STR(barrier_num_str, INT_STR_SIZE, SIZE_T_TEMPLATE, barrier_num_full); - kvs_set_value(KVS_BARRIER_FULL, std::to_string(assigned_thread_idx).c_str(), barrier_num_str); + ATL_CHECK_STATUS( + kvs_set_value( + KVS_BARRIER_FULL, 
std::to_string(assigned_thread_idx).c_str(), barrier_num_str), + "failed to set barrier num"); - min_barrier_num = get_barrier_full_idx(); + ATL_CHECK_STATUS(get_barrier_full_idx(min_barrier_num), "failed to get barrier num"); while (min_barrier_num != barrier_num) { - min_barrier_num = get_barrier_idx(); + ATL_CHECK_STATUS(get_barrier_idx(min_barrier_num), "failed to get barrier num"); } barrier_num_full++; if (barrier_num_full > BARRIER_NUM_MAX) barrier_num_full = 0; + return ATL_STATUS_SUCCESS; } -size_t pmi_resizable_simple::get_barrier_full_idx() { +atl_status_t pmi_resizable_simple::get_barrier_full_idx(size_t& res) { + res = 0; size_t thread_count = ranks_per_thread_map.size(); - kvs_get_value(KVS_BARRIER_FULL, std::to_string(0).c_str(), val_storage); + ATL_CHECK_STATUS(kvs_get_value(KVS_BARRIER_FULL, std::to_string(0).c_str(), val_storage), + "failed to get barrier idx"); size_t min_barrier_idx = atoi(val_storage); size_t barrier_idx; for (size_t i = 1; i < thread_count; i++) { - kvs_get_value(KVS_BARRIER_FULL, std::to_string(i).c_str(), val_storage); + ATL_CHECK_STATUS(kvs_get_value(KVS_BARRIER_FULL, std::to_string(i).c_str(), val_storage), + "failed to get barrier idx"); barrier_idx = atoi(val_storage); if (min_barrier_idx > barrier_idx) min_barrier_idx = barrier_idx; } - - return min_barrier_idx; + res = min_barrier_idx; + return ATL_STATUS_SUCCESS; } + atl_status_t pmi_resizable_simple::pmrt_kvs_put(char* kvs_key, int proc_idx, const void* kvs_val, @@ -180,14 +195,18 @@ atl_status_t pmi_resizable_simple::pmrt_kvs_put(char* kvs_key, return ATL_STATUS_FAILURE; ret = snprintf(key_storage, max_keylen - 1, RESIZABLE_PMI_RT_KEY_FORMAT, kvs_key, proc_idx); - if (ret < 0) + if (ret < 0) { + LOG_ERROR("sprintf failed"); return ATL_STATUS_FAILURE; + } ret = encode(kvs_val, kvs_val_len, val_storage, max_vallen); - if (ret) + if (ret) { + LOG_ERROR("encode failed"); return ATL_STATUS_FAILURE; + } - kvs_set_value(KVS_NAME, key_storage, val_storage); + 
ATL_CHECK_STATUS(kvs_set_value(KVS_NAME, key_storage, val_storage), "failed to set val"); return ATL_STATUS_SUCCESS; } @@ -200,14 +219,18 @@ atl_status_t pmi_resizable_simple::pmrt_kvs_get(char* kvs_key, char key_storage[max_keylen]; ret = snprintf(key_storage, max_keylen - 1, RESIZABLE_PMI_RT_KEY_FORMAT, kvs_key, proc_idx); - if (ret < 0) + if (ret < 0) { + LOG_ERROR("sprintf failed"); return ATL_STATUS_FAILURE; + } - kvs_get_value(KVS_NAME, key_storage, val_storage); + ATL_CHECK_STATUS(kvs_get_value(KVS_NAME, key_storage, val_storage), "failed to get val"); ret = decode(val_storage, kvs_val, kvs_val_len); - if (ret) + if (ret) { + LOG_ERROR("encode failed"); return ATL_STATUS_FAILURE; + } return ATL_STATUS_SUCCESS; } @@ -224,100 +247,118 @@ size_t pmi_resizable_simple::get_local_thread_idx() { return local_thread_idx; } -int pmi_resizable_simple::kvs_set_value(const char* kvs_name, const char* key, const char* value) { +atl_status_t pmi_resizable_simple::kvs_set_value(const char* kvs_name, + const char* key, + const char* value) { std::string result_kvs_name = std::string(kvs_name) + std::to_string(local_id); put_key(result_kvs_name.c_str(), key, value, ST_CLIENT); - return k->kvs_set_value(result_kvs_name.c_str(), key, value); + return (k->kvs_set_value(result_kvs_name.c_str(), key, value) == KVS_STATUS_SUCCESS) + ? 
ATL_STATUS_SUCCESS + : ATL_STATUS_FAILURE; } -int pmi_resizable_simple::kvs_get_value(const char* kvs_name, const char* key, char* value) { +atl_status_t pmi_resizable_simple::kvs_get_value(const char* kvs_name, + const char* key, + char* value) { std::string result_kvs_name = std::string(kvs_name) + std::to_string(local_id); time_t start_time = time(NULL); size_t kvs_get_time = 0; - while (k->kvs_get_value_by_name_key(result_kvs_name.c_str(), key, value) == 0 && - kvs_get_time < kvs_get_timeout) { + do { + KVS_2_ATL_CHECK_STATUS(k->kvs_get_value_by_name_key(result_kvs_name.c_str(), key, value), + "failed to get value"); kvs_get_time = time(NULL) - start_time; - } + } while (strlen(value) == 0 && kvs_get_time < kvs_get_timeout); if (kvs_get_time >= kvs_get_timeout) { - printf("KVS get error: timeout limit (%zu > %zu), prefix: %s, key %s\n", - kvs_get_time, - kvs_get_timeout, - result_kvs_name.c_str(), - key); - exit(1); + LOG_ERROR("KVS get error: timeout limit (%zu > %zu), prefix: %s, key %s\n", + kvs_get_time, + kvs_get_timeout, + result_kvs_name.c_str(), + key); + return ATL_STATUS_FAILURE; } return ATL_STATUS_SUCCESS; } -int pmi_resizable_simple::kvs_iget_value(const char* kvs_name, const char* key, char* value) { +atl_status_t pmi_resizable_simple::kvs_iget_value(const char* kvs_name, + const char* key, + char* value) { std::string result_kvs_name = std::string(kvs_name) + std::to_string(local_id); - return k->kvs_get_value_by_name_key(result_kvs_name.c_str(), key, value); + return k->kvs_get_value_by_name_key(result_kvs_name.c_str(), key, value) == KVS_STATUS_SUCCESS + ? 
ATL_STATUS_SUCCESS + : ATL_STATUS_FAILURE; } -size_t pmi_resizable_simple::get_barrier_idx() { +atl_status_t pmi_resizable_simple::get_barrier_idx(size_t& barrier_num_out) { size_t proc_count = threads_per_proc.size(); + barrier_num_out = 0; - kvs_get_value(KVS_BARRIER, std::to_string(0).c_str(), val_storage); + ATL_CHECK_STATUS(kvs_get_value(KVS_BARRIER, std::to_string(0).c_str(), val_storage), + "failed to get barrier"); size_t min_barrier_idx = atoi(val_storage); size_t barrier_idx; for (size_t i = 1; i < proc_count; i++) { - kvs_get_value(KVS_BARRIER, std::to_string(i).c_str(), val_storage); - + ATL_CHECK_STATUS(kvs_get_value(KVS_BARRIER, std::to_string(i).c_str(), val_storage), + "failed to get barrier"); barrier_idx = atoi(val_storage); if (min_barrier_idx > barrier_idx) min_barrier_idx = barrier_idx; } - return min_barrier_idx; + barrier_num_out = min_barrier_idx; + return ATL_STATUS_SUCCESS; } -void pmi_resizable_simple::register_first_rank_idx_and_rank_count() { - kvs_set_value( +atl_status_t pmi_resizable_simple::register_first_rank_idx_and_rank_count() { + return kvs_set_value( RANKS_PER_THREAD, std::to_string(ranks[0]).c_str(), std::to_string(ranks.size()).c_str()); } -void pmi_resizable_simple::assign_thread_idx_and_fill_ranks_per_thread_map() { +atl_status_t pmi_resizable_simple::assign_thread_idx_and_fill_ranks_per_thread_map() { int rank_count = 0; int ranks_per_thread; while (rank_count < total_rank_count) { if (rank_count == ranks[0]) { assigned_thread_idx = ranks_per_thread_map.size(); } - kvs_get_value(RANKS_PER_THREAD, std::to_string(rank_count).c_str(), val_storage); + ATL_CHECK_STATUS( + kvs_get_value(RANKS_PER_THREAD, std::to_string(rank_count).c_str(), val_storage), + "failed to get ranks"); - ranks_per_thread = safe_strtol(val_storage, NULL, 10); + ranks_per_thread = std::atoi(val_storage); ranks_per_thread_map.push_back(ranks_per_thread); rank_count += ranks_per_thread; } + return ATL_STATUS_SUCCESS; } -void 
pmi_resizable_simple::register_my_proc_name() { +atl_status_t pmi_resizable_simple::register_my_proc_name() { int my_pid = getpid(); const int hostname_len = 1024; char hostname[hostname_len]; int ret = gethostname(hostname, hostname_len); if (ret) { - printf("gethostname error: %s\n", strerror(errno)); - exit(EXIT_FAILURE); + LOG_ERROR("gethostname error: %s\n", strerror(errno)); + return ATL_STATUS_FAILURE; } my_proccess_name = std::string(hostname) + std::to_string(my_pid); - kvs_set_value( + return kvs_set_value( PROCESS_THREAD_NAME, std::to_string(assigned_thread_idx).c_str(), my_proccess_name.c_str()); } -void pmi_resizable_simple::get_my_proc_idx_and_proc_count() { +atl_status_t pmi_resizable_simple::get_my_proc_idx_and_proc_count() { std::map proc_name_to_rank; std::map::iterator it; int rank; for (size_t i = 0; i < ranks_per_thread_map.size(); i++) { - kvs_get_value(PROCESS_THREAD_NAME, std::to_string(i).c_str(), val_storage); + ATL_CHECK_STATUS(kvs_get_value(PROCESS_THREAD_NAME, std::to_string(i).c_str(), val_storage), + "failed to get proc name"); it = proc_name_to_rank.find(val_storage); if (it == proc_name_to_rank.end()) { @@ -325,9 +366,10 @@ void pmi_resizable_simple::get_my_proc_idx_and_proc_count() { if (!my_proccess_name.compare(val_storage)) { assigned_proc_idx = rank; if (assigned_thread_idx == i) { - kvs_set_value(REQUESTED_RANK_TO_NAME, - std::to_string(assigned_proc_idx).c_str(), - my_proccess_name.c_str()); + ATL_CHECK_STATUS(kvs_set_value(REQUESTED_RANK_TO_NAME, + std::to_string(assigned_proc_idx).c_str(), + my_proccess_name.c_str()), + "failed to set proc name"); } } proc_name_to_rank[val_storage] = rank; @@ -337,6 +379,7 @@ void pmi_resizable_simple::get_my_proc_idx_and_proc_count() { threads_per_proc[it->second].push_back(i); } } + return ATL_STATUS_SUCCESS; } void pmi_resizable_simple::calculate_local_thread_idx() { @@ -350,55 +393,78 @@ void pmi_resizable_simple::calculate_local_thread_idx() { } } -void 
pmi_resizable_simple::make_map_requested2global() { +atl_status_t pmi_resizable_simple::make_map_requested2global() { char global_rank_str[MAX_KVS_VAL_LENGTH]; char process_name[MAX_KVS_VAL_LENGTH]; size_t size = get_size(); requested2global.resize(size); - pmrt_barrier_full(); + ATL_CHECK_STATUS(pmrt_barrier_full(), "make_map_requested2global: full barrier failed"); for (size_t i = 0; i < size; i++) { - kvs_get_value(REQUESTED_RANK_TO_NAME, std::to_string(i).c_str(), process_name); - if (kvs_iget_value(GLOBAL_NAME_TO_RANK, process_name, global_rank_str) == 0) { + ATL_CHECK_STATUS( + kvs_get_value(REQUESTED_RANK_TO_NAME, std::to_string(i).c_str(), process_name), + "make_map_requested2global: failed to get proc name"); + ATL_CHECK_STATUS(kvs_iget_value(GLOBAL_NAME_TO_RANK, process_name, global_rank_str), + "make_map_requested2global: failed to get glob rank"); + if (strlen(global_rank_str) == 0) { if (!my_proccess_name.compare(process_name)) { int free_glob_rank = 0; - while (kvs_iget_value(GLOBAL_RANK_TO_NAME, - std::to_string(free_glob_rank).c_str(), - process_name) != 0) { + ATL_CHECK_STATUS( + kvs_iget_value( + GLOBAL_RANK_TO_NAME, std::to_string(free_glob_rank).c_str(), process_name), + "make_map_requested2global: failed to get proc name"); + while (strlen(process_name) != 0) { free_glob_rank++; + ATL_CHECK_STATUS(kvs_iget_value(GLOBAL_RANK_TO_NAME, + std::to_string(free_glob_rank).c_str(), + process_name), + "make_map_requested2global: failed to get proc name"); } - kvs_set_value(GLOBAL_RANK_TO_NAME, - std::to_string(free_glob_rank).c_str(), - my_proccess_name.c_str()); - kvs_set_value(GLOBAL_NAME_TO_RANK, - my_proccess_name.c_str(), - std::to_string(free_glob_rank).c_str()); + ATL_CHECK_STATUS(kvs_set_value(GLOBAL_RANK_TO_NAME, + std::to_string(free_glob_rank).c_str(), + my_proccess_name.c_str()), + "make_map_requested2global: failed to set proc name"); + ATL_CHECK_STATUS(kvs_set_value(GLOBAL_NAME_TO_RANK, + my_proccess_name.c_str(), + 
std::to_string(free_glob_rank).c_str()), + "make_map_requested2global: failed to set free rank info"); } - kvs_get_value(GLOBAL_NAME_TO_RANK, process_name, global_rank_str); + ATL_CHECK_STATUS(kvs_get_value(GLOBAL_NAME_TO_RANK, process_name, global_rank_str), + "make_map_requested2global: failed to get rank info"); } requested2global[i] = atoi(global_rank_str); } - pmrt_barrier_full(); + ATL_CHECK_STATUS(pmrt_barrier_full(), "make_map_requested2global: full barrier failed"); + return ATL_STATUS_SUCCESS; } -size_t pmi_resizable_simple::get_local_kvs_id() { +atl_status_t pmi_resizable_simple::get_local_kvs_id(size_t& res) { char local_kvs_id[MAX_KVS_VAL_LENGTH]; + res = 0; /*TODO: change it for collect local_per_rank id, not global*/ - if (k->kvs_get_value_by_name_key(LOCAL_KVS_ID, "ID", local_kvs_id) == 0) - return 0; - return atoi(local_kvs_id); + KVS_2_ATL_CHECK_STATUS(k->kvs_get_value_by_name_key(LOCAL_KVS_ID, "ID", local_kvs_id), + "failed to get local kvs id"); + res = atoi(local_kvs_id); + return ATL_STATUS_SUCCESS; } -void pmi_resizable_simple::set_local_kvs_id(size_t local_kvs_id) { +atl_status_t pmi_resizable_simple::set_local_kvs_id(size_t local_kvs_id) { /*TODO: change it for collect local_per_rank id, not global*/ put_key(LOCAL_KVS_ID, "ID", std::to_string(local_kvs_id).c_str(), ST_CLIENT); - k->kvs_set_value(LOCAL_KVS_ID, "ID", std::to_string(local_kvs_id).c_str()); + return (k->kvs_set_value(LOCAL_KVS_ID, "ID", std::to_string(local_kvs_id).c_str()) == + KVS_STATUS_SUCCESS) + ? 
ATL_STATUS_SUCCESS + : ATL_STATUS_FAILURE; } pmi_resizable_simple::~pmi_resizable_simple() { - if (!is_finalized) - pmrt_finalize(); + if (!is_finalized) { + CCL_THROW_IF_NOT(pmrt_finalize() == ATL_STATUS_SUCCESS, "~pmi_resizable_simple: failed"); + } } -void pmi_resizable_simple::remove_initial_data() { +atl_status_t pmi_resizable_simple::remove_initial_data() { std::string result_kvs_name = std::string(RANKS_PER_THREAD) + std::to_string(0); remove_val(result_kvs_name.c_str(), std::to_string(ranks[0]).c_str(), ST_CLIENT); - k->kvs_remove_name_key(result_kvs_name.c_str(), std::to_string(ranks[0]).c_str()); + return k->kvs_remove_name_key(result_kvs_name.c_str(), std::to_string(ranks[0]).c_str()) == + KVS_STATUS_SUCCESS + ? ATL_STATUS_SUCCESS + : ATL_STATUS_FAILURE; } diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple.h b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple.h index 3475bd2d9..8bb255883 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple.h +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple.h @@ -45,7 +45,7 @@ class pmi_resizable_simple final : public ipmi { pmi_resizable_simple(int total_rank_count, const std::vector& ranks, std::shared_ptr k, - const char* main_addr = nullptr); + const char* main_addr = ""); ~pmi_resizable_simple() override; @@ -59,7 +59,7 @@ class pmi_resizable_simple final : public ipmi { atl_status_t pmrt_wait_notification() override; - void pmrt_barrier() override; + atl_status_t pmrt_barrier() override; atl_status_t pmrt_kvs_put(char* kvs_key, int proc_idx, @@ -77,9 +77,9 @@ class pmi_resizable_simple final : public ipmi { size_t get_local_thread_idx() override; - size_t get_local_kvs_id() override; + atl_status_t get_local_kvs_id(size_t& res) override; - void set_local_kvs_id(size_t local_kvs_id) override; + atl_status_t set_local_kvs_id(size_t local_kvs_id) override; size_t get_threads_per_process() override { return threads_per_proc[assigned_proc_idx].size(); @@ -94,28 +94,29 @@ class 
pmi_resizable_simple final : public ipmi { return res; } - void pmrt_finalize() override; + atl_status_t pmrt_finalize() override; + + atl_status_t pmrt_init() override; private: bool is_finalized{ false }; - atl_status_t pmrt_init(const char* main_addr = nullptr); - int kvs_set_value(const char* kvs_name, const char* key, const char* value); - int kvs_get_value(const char* kvs_name, const char* key, char* value); - int kvs_iget_value(const char* kvs_name, const char* key, char* value); + atl_status_t kvs_set_value(const char* kvs_name, const char* key, const char* value); + atl_status_t kvs_get_value(const char* kvs_name, const char* key, char* value); + atl_status_t kvs_iget_value(const char* kvs_name, const char* key, char* value); - size_t get_barrier_idx(); - size_t get_barrier_full_idx(); + atl_status_t get_barrier_idx(size_t& barrier_num_out); + atl_status_t get_barrier_full_idx(size_t& res); void calculate_local_thread_idx(); - void register_first_rank_idx_and_rank_count(); - void assign_thread_idx_and_fill_ranks_per_thread_map(); - void register_my_proc_name(); - void get_my_proc_idx_and_proc_count(); - void make_requested_info(); - void remove_initial_data(); - void make_map_requested2global(); - void pmrt_barrier_full(); + atl_status_t register_first_rank_idx_and_rank_count(); + atl_status_t assign_thread_idx_and_fill_ranks_per_thread_map(); + atl_status_t register_my_proc_name(); + atl_status_t get_my_proc_idx_and_proc_count(); + atl_status_t make_requested_info(); + atl_status_t remove_initial_data(); + atl_status_t make_map_requested2global(); + atl_status_t pmrt_barrier_full(); int total_rank_count; int assigned_proc_idx; @@ -127,6 +128,7 @@ class pmi_resizable_simple final : public ipmi { std::vector ranks_per_thread_map; std::map> threads_per_proc; std::shared_ptr k; + std::string main_addr; size_t max_keylen; size_t max_vallen; char* val_storage = nullptr; diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple_internal.cpp 
b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple_internal.cpp index 1975d368e..0397f54d4 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple_internal.cpp +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple_internal.cpp @@ -38,17 +38,17 @@ pmi_resizable_simple_internal::pmi_resizable_simple_internal(int size, const char* main_addr) : total_rank_count(size), ranks(ranks), - k(k) { + k(k), + main_addr(main_addr) { max_keylen = MAX_KVS_KEY_LENGTH; max_vallen = MAX_KVS_VAL_LENGTH; - pmrt_init(main_addr); } int pmi_resizable_simple_internal::is_pm_resize_enabled() { return 0; } -atl_status_t pmi_resizable_simple_internal::pmrt_init(const char* main_addr) { +atl_status_t pmi_resizable_simple_internal::pmrt_init() { (void)main_addr; char* kvs_get_timeout_str = getenv("CCL_KVS_GET_TIMEOUT"); @@ -58,26 +58,29 @@ atl_status_t pmi_resizable_simple_internal::pmrt_init(const char* main_addr) { local_id = 0; val_storage = (char*)calloc(1, max_vallen); - if (!val_storage) + if (!val_storage) { + LOG_ERROR("mem alloc failed"); return ATL_STATUS_FAILURE; - local_id = get_local_kvs_id(); - barrier_full_reg(); + } + ATL_CHECK_STATUS(get_local_kvs_id(local_id), "failed to get local id"); + ATL_CHECK_STATUS(barrier_full_reg(), "failed to full_barrier info register"); - registration(); + ATL_CHECK_STATUS(registration(), "registration failed"); if (ranks[0] == 0) { - size_t tmp_local_id = get_local_kvs_id(); + size_t tmp_local_id; + ATL_CHECK_STATUS(get_local_kvs_id(tmp_local_id), "failed to get local id"); tmp_local_id++; - set_local_kvs_id(tmp_local_id); + ATL_CHECK_STATUS(set_local_kvs_id(tmp_local_id), "failed to set local id"); } if (thread_num == 0) { - barrier_reg(); + ATL_CHECK_STATUS(barrier_reg(), "failed to barrier info register"); } return ATL_STATUS_SUCCESS; } -void pmi_resizable_simple_internal::registration() { +atl_status_t pmi_resizable_simple_internal::registration() { std::string total_local_rank_count_str = 
std::to_string(total_rank_count); std::string result_kvs_name = std::string(INTERNAL_REGISTRATION) + std::to_string(local_id); memset(val_storage, 0, max_vallen); @@ -88,10 +91,14 @@ void pmi_resizable_simple_internal::registration() { ranks[0], getpid(), gettid()); - k->kvs_set_size( - result_kvs_name.c_str(), result_kvs_name.c_str(), total_local_rank_count_str.c_str()); + KVS_2_ATL_CHECK_STATUS( + k->kvs_set_size( + result_kvs_name.c_str(), result_kvs_name.c_str(), total_local_rank_count_str.c_str()), + "failed to set total rank count"); /*return string: %PROC_COUNT%_%RANK_NUM%_%PROCESS_RANK_COUNT%_%THREADS_COUNT%_%THREAD_NUM% */ - k->kvs_register(result_kvs_name.c_str(), result_kvs_name.c_str(), val_storage); + KVS_2_ATL_CHECK_STATUS( + k->kvs_register(result_kvs_name.c_str(), result_kvs_name.c_str(), val_storage), + "failed to register"); char* proc_count_str = val_storage; char* rank_str = strstr(proc_count_str, "_"); @@ -112,53 +119,63 @@ void pmi_resizable_simple_internal::registration() { proc_rank_count = std::stoi(proc_rank_count_str); threads_count = std::stoi(threads_count_str); thread_num = std::stoi(thread_num_str); + return ATL_STATUS_SUCCESS; } -void pmi_resizable_simple_internal::barrier_full_reg() { +atl_status_t pmi_resizable_simple_internal::barrier_full_reg() { std::string empty_line(""); std::string total_local_rank_count_str = std::to_string(total_rank_count) + "_" + std::to_string(ranks.size()); std::string result_kvs_name = std::string(KVS_BARRIER_FULL) + std::to_string(local_id); - k->kvs_barrier_register( - result_kvs_name.c_str(), result_kvs_name.c_str(), total_local_rank_count_str.c_str()); - pmrt_barrier_full(); + KVS_2_ATL_CHECK_STATUS( + k->kvs_barrier_register( + result_kvs_name.c_str(), result_kvs_name.c_str(), total_local_rank_count_str.c_str()), + "registration failed"); + ATL_CHECK_STATUS(pmrt_barrier_full(), "full barrier failed"); + return ATL_STATUS_SUCCESS; } -void pmi_resizable_simple_internal::barrier_reg() { +atl_status_t 
pmi_resizable_simple_internal::barrier_reg() { std::string empty_line(""); std::string proc_count_str = std::to_string(proc_count); std::string result_kvs_name = std::string(KVS_BARRIER) + std::to_string(local_id); - k->kvs_barrier_register( - result_kvs_name.c_str(), result_kvs_name.c_str(), proc_count_str.c_str()); - pmrt_barrier_full(); + KVS_2_ATL_CHECK_STATUS( + k->kvs_barrier_register( + result_kvs_name.c_str(), result_kvs_name.c_str(), proc_count_str.c_str()), + "registration failed"); + ATL_CHECK_STATUS(pmrt_barrier_full(), "full barrier failed"); + return ATL_STATUS_SUCCESS; } atl_status_t pmi_resizable_simple_internal::pmrt_main_addr_reserve(char* main_addr) { + LOG_ERROR("unsupported"); return ATL_STATUS_UNSUPPORTED; } atl_status_t pmi_resizable_simple_internal::pmrt_set_resize_function(atl_resize_fn_t resize_fn) { + LOG_ERROR("unsupported"); return ATL_STATUS_UNSUPPORTED; } atl_status_t pmi_resizable_simple_internal::pmrt_update() { + LOG_ERROR("unsupported"); return ATL_STATUS_UNSUPPORTED; } atl_status_t pmi_resizable_simple_internal::pmrt_wait_notification() { + LOG_ERROR("unsupported"); return ATL_STATUS_UNSUPPORTED; } -void pmi_resizable_simple_internal::pmrt_finalize() { +atl_status_t pmi_resizable_simple_internal::pmrt_finalize() { is_finalized = true; free(val_storage); if (getenv("CCL_PMI_FORCE_FINALIZE")) { - printf("skip pmi_resizable_simple::pmrt_finalize\n"); - fflush(stdout); - return; + LOG_WARN("skip pmi_resizable_simple::pmrt_finalize\n"); + return ATL_STATUS_SUCCESS; } char kvs_name[MAX_KVS_NAME_LENGTH]; @@ -166,22 +183,29 @@ void pmi_resizable_simple_internal::pmrt_finalize() { char kvs_val[MAX_KVS_VAL_LENGTH]; while (cut_head(kvs_name, kvs_key, kvs_val, ST_CLIENT)) { - k->kvs_remove_name_key(kvs_name, kvs_key); + KVS_2_ATL_CHECK_STATUS(k->kvs_remove_name_key(kvs_name, kvs_key), "failed to remove info"); } + return ATL_STATUS_SUCCESS; } -void pmi_resizable_simple_internal::pmrt_barrier() { +atl_status_t 
pmi_resizable_simple_internal::pmrt_barrier() { std::string empty_line(""); std::string result_kvs_name = std::string(KVS_BARRIER) + std::to_string(local_id); - k->kvs_barrier(result_kvs_name.c_str(), result_kvs_name.c_str(), empty_line.c_str()); + return k->kvs_barrier(result_kvs_name.c_str(), result_kvs_name.c_str(), empty_line.c_str()) == + KVS_STATUS_SUCCESS + ? ATL_STATUS_SUCCESS + : ATL_STATUS_FAILURE; } -void pmi_resizable_simple_internal::pmrt_barrier_full() { +atl_status_t pmi_resizable_simple_internal::pmrt_barrier_full() { std::string empty_line(""); std::string result_kvs_name = std::string(KVS_BARRIER_FULL) + std::to_string(local_id); - k->kvs_barrier(result_kvs_name.c_str(), result_kvs_name.c_str(), (empty_line.c_str())); + return k->kvs_barrier(result_kvs_name.c_str(), result_kvs_name.c_str(), (empty_line.c_str())) == + KVS_STATUS_SUCCESS + ? ATL_STATUS_SUCCESS + : ATL_STATUS_FAILURE; } atl_status_t pmi_resizable_simple_internal::pmrt_kvs_put(char* kvs_key, @@ -190,18 +214,24 @@ atl_status_t pmi_resizable_simple_internal::pmrt_kvs_put(char* kvs_key, size_t kvs_val_len) { int ret; char key_storage[max_keylen]; - if (kvs_val_len > max_vallen) + if (kvs_val_len > max_vallen) { + LOG_ERROR("asked len > max len"); return ATL_STATUS_FAILURE; + } ret = snprintf(key_storage, max_keylen - 1, RESIZABLE_PMI_RT_KEY_FORMAT, kvs_key, proc_idx); - if (ret < 0) + if (ret < 0) { + LOG_ERROR("snprintf failed"); return ATL_STATUS_FAILURE; + } ret = encode(kvs_val, kvs_val_len, val_storage, max_vallen); - if (ret) + if (ret) { + LOG_ERROR("encode failed"); return ATL_STATUS_FAILURE; + } - kvs_set_value(KVS_NAME, key_storage, val_storage); + ATL_CHECK_STATUS(kvs_set_value(KVS_NAME, key_storage, val_storage), "failed to set val"); return ATL_STATUS_SUCCESS; } @@ -214,14 +244,18 @@ atl_status_t pmi_resizable_simple_internal::pmrt_kvs_get(char* kvs_key, char key_storage[max_keylen]; ret = snprintf(key_storage, max_keylen - 1, RESIZABLE_PMI_RT_KEY_FORMAT, kvs_key, proc_idx); 
- if (ret < 0) + if (ret < 0) { + LOG_ERROR("snprintf failed"); return ATL_STATUS_FAILURE; + } - kvs_get_value(KVS_NAME, key_storage, val_storage); + ATL_CHECK_STATUS(kvs_get_value(KVS_NAME, key_storage, val_storage), "failed to get val"); ret = decode(val_storage, kvs_val, kvs_val_len); - if (ret) + if (ret) { + LOG_ERROR("decode failed"); return ATL_STATUS_FAILURE; + } return ATL_STATUS_SUCCESS; } @@ -255,46 +289,53 @@ int pmi_resizable_simple_internal::kvs_set_value(const char* kvs_name, return k->kvs_set_value(result_kvs_name.c_str(), key, value); } -int pmi_resizable_simple_internal::kvs_get_value(const char* kvs_name, - const char* key, - char* value) { +atl_status_t pmi_resizable_simple_internal::kvs_get_value(const char* kvs_name, + const char* key, + char* value) { std::string result_kvs_name = std::string(kvs_name) + std::to_string(local_id); time_t start_time = time(NULL); size_t kvs_get_time = 0; - while (k->kvs_get_value_by_name_key(result_kvs_name.c_str(), key, value) == 0 && - kvs_get_time < kvs_get_timeout) { + do { + KVS_2_ATL_CHECK_STATUS(k->kvs_get_value_by_name_key(result_kvs_name.c_str(), key, value), + "failed to get value"); kvs_get_time = time(NULL) - start_time; - } + } while (strlen(value) == 0 && kvs_get_time < kvs_get_timeout); if (kvs_get_time >= kvs_get_timeout) { - printf("KVS get error: timeout limit (%zu > %zu), prefix: %s, key %s\n", - kvs_get_time, - kvs_get_timeout, - result_kvs_name.c_str(), - key); - exit(1); + LOG_ERROR("KVS get error: timeout limit (%zu > %zu), prefix: %s, key %s\n", + kvs_get_time, + kvs_get_timeout, + result_kvs_name.c_str(), + key); + return ATL_STATUS_FAILURE; } - return ATL_STATUS_SUCCESS; } -size_t pmi_resizable_simple_internal::get_local_kvs_id() { +atl_status_t pmi_resizable_simple_internal::get_local_kvs_id(size_t& res) { char local_kvs_id[MAX_KVS_VAL_LENGTH]; + res = 0; /*TODO: change it for collect local_per_rank id, not global*/ - if (k->kvs_get_value_by_name_key(LOCAL_KVS_ID, "ID", local_kvs_id) 
== 0) - return 0; - return atoi(local_kvs_id); + KVS_2_ATL_CHECK_STATUS(k->kvs_get_value_by_name_key(LOCAL_KVS_ID, "ID", local_kvs_id), + "failed to get local kvs id"); + res = atoi(local_kvs_id); + return ATL_STATUS_SUCCESS; } -void pmi_resizable_simple_internal::set_local_kvs_id(size_t local_kvs_id) { +atl_status_t pmi_resizable_simple_internal::set_local_kvs_id(size_t local_kvs_id) { /*TODO: change it for collect local_per_rank id, not global*/ put_key(LOCAL_KVS_ID, "ID", std::to_string(local_kvs_id).c_str(), ST_CLIENT); - k->kvs_set_value(LOCAL_KVS_ID, "ID", std::to_string(local_kvs_id).c_str()); + return k->kvs_set_value(LOCAL_KVS_ID, "ID", std::to_string(local_kvs_id).c_str()) == + KVS_STATUS_SUCCESS + ? ATL_STATUS_SUCCESS + : ATL_STATUS_FAILURE; } pmi_resizable_simple_internal::~pmi_resizable_simple_internal() { - if (!is_finalized) - pmrt_finalize(); + if (!is_finalized) { + CCL_THROW_IF_NOT(pmrt_finalize() == ATL_STATUS_SUCCESS, + "~pmi_resizable_simple_internal: failed"); + } } diff --git a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple_internal.h b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple_internal.h index 566e5b371..00ed1ac33 100644 --- a/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple_internal.h +++ b/src/atl/util/pm/pmi_resizable_rt/pmi_resizable_simple_internal.h @@ -45,7 +45,7 @@ class pmi_resizable_simple_internal final : public ipmi { pmi_resizable_simple_internal(int total_rank_count, const std::vector& ranks, std::shared_ptr k, - const char* main_addr = nullptr); + const char* main_addr = ""); ~pmi_resizable_simple_internal() override; @@ -59,7 +59,7 @@ class pmi_resizable_simple_internal final : public ipmi { atl_status_t pmrt_wait_notification() override; - void pmrt_barrier() override; + atl_status_t pmrt_barrier() override; atl_status_t pmrt_kvs_put(char* kvs_key, int proc_idx, @@ -77,27 +77,28 @@ class pmi_resizable_simple_internal final : public ipmi { size_t get_local_thread_idx() override; - size_t 
get_local_kvs_id() override; + atl_status_t get_local_kvs_id(size_t& res) override; - void set_local_kvs_id(size_t local_kvs_id) override; + atl_status_t set_local_kvs_id(size_t local_kvs_id) override; size_t get_threads_per_process() override; size_t get_ranks_per_process() override; - void pmrt_finalize() override; + atl_status_t pmrt_finalize() override; + + atl_status_t pmrt_init() override; private: bool is_finalized{ false }; - atl_status_t pmrt_init(const char* main_addr = nullptr); int kvs_set_value(const char* kvs_name, const char* key, const char* value); - int kvs_get_value(const char* kvs_name, const char* key, char* value); + atl_status_t kvs_get_value(const char* kvs_name, const char* key, char* value); - void pmrt_barrier_full(); - void barrier_full_reg(); - void barrier_reg(); - void registration(); + atl_status_t pmrt_barrier_full(); + atl_status_t barrier_full_reg(); + atl_status_t barrier_reg(); + atl_status_t registration(); int proc_count = 0; int rank = 0; @@ -109,6 +110,7 @@ class pmi_resizable_simple_internal final : public ipmi { std::vector ranks; std::shared_ptr k; + std::string main_addr; size_t max_keylen; size_t max_vallen; char* val_storage = nullptr; diff --git a/src/atl/util/pm/pmi_rt/pmi_simple.cpp b/src/atl/util/pm/pmi_rt/pmi_simple.cpp index 6e14b8529..9526f900a 100644 --- a/src/atl/util/pm/pmi_rt/pmi_simple.cpp +++ b/src/atl/util/pm/pmi_rt/pmi_simple.cpp @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ +#include "common/log/log.hpp" #include "pmi_simple.h" #include "pmi_rt.c" @@ -20,37 +21,41 @@ int pmi_simple::is_pm_resize_enabled() { return false; } -pmi_simple::pmi_simple() { - pmirt_init(&rank, &size, &pmrt_desc); +pmi_simple::pmi_simple() {} + +atl_status_t pmi_simple::pmrt_init() { + return pmirt_init(&rank, &size, &pmrt_desc); } atl_status_t pmi_simple::pmrt_main_addr_reserve(char *main_addr) { - printf("Function main_addr_reserv unsupported yet for simple pmi\n"); + LOG_ERROR("Function main_addr_reserv unsupported yet for simple pmi\n"); return ATL_STATUS_FAILURE; } atl_status_t pmi_simple::pmrt_set_resize_function(atl_resize_fn_t resize_fn) { - printf("Function set_resize_function unsupported yet for simple pmi\n"); + LOG_ERROR("Function set_resize_function unsupported yet for simple pmi\n"); return ATL_STATUS_FAILURE; } atl_status_t pmi_simple::pmrt_update() { - printf("Function update unsupported yet for simple pmi\n"); + LOG_ERROR("Function update unsupported yet for simple pmi\n"); return ATL_STATUS_FAILURE; } atl_status_t pmi_simple::pmrt_wait_notification() { - printf("Function wait_notification unsupported yet for simple pmi\n"); + LOG_ERROR("Function wait_notification unsupported yet for simple pmi\n"); return ATL_STATUS_FAILURE; } -void pmi_simple::pmrt_finalize() { +atl_status_t pmi_simple::pmrt_finalize() { is_finalized = true; pmirt_finalize(pmrt_desc); + return ATL_STATUS_SUCCESS; } -void pmi_simple::pmrt_barrier() { +atl_status_t pmi_simple::pmrt_barrier() { pmirt_barrier(pmrt_desc); + return ATL_STATUS_SUCCESS; } atl_status_t pmi_simple::pmrt_kvs_put(char *kvs_key, @@ -78,12 +83,16 @@ int pmi_simple::get_size() { size_t pmi_simple::get_local_thread_idx() { return 0; } -size_t pmi_simple::get_local_kvs_id() { - return 0; +atl_status_t pmi_simple::get_local_kvs_id(size_t &res) { + res = 0; + return ATL_STATUS_SUCCESS; +} +atl_status_t pmi_simple::set_local_kvs_id(size_t local_kvs_id) { + return ATL_STATUS_SUCCESS; } -void 
pmi_simple::set_local_kvs_id(size_t local_kvs_id) {} pmi_simple::~pmi_simple() { - if (!is_finalized) - pmrt_finalize(); + if (!is_finalized) { + CCL_THROW_IF_NOT(pmrt_finalize() == ATL_STATUS_SUCCESS, "~pmi_simple: failed"); + } } diff --git a/src/atl/util/pm/pmi_rt/pmi_simple.h b/src/atl/util/pm/pmi_rt/pmi_simple.h index 27d8b0571..8ce68407b 100644 --- a/src/atl/util/pm/pmi_rt/pmi_simple.h +++ b/src/atl/util/pm/pmi_rt/pmi_simple.h @@ -31,9 +31,9 @@ class pmi_simple final : public ipmi { atl_status_t pmrt_wait_notification() override; - void pmrt_finalize() override; + atl_status_t pmrt_finalize() override; - void pmrt_barrier() override; + atl_status_t pmrt_barrier() override; atl_status_t pmrt_kvs_put(char *kvs_key, int proc_idx, @@ -51,9 +51,9 @@ class pmi_simple final : public ipmi { size_t get_local_thread_idx() override; - size_t get_local_kvs_id() override; + atl_status_t get_local_kvs_id(size_t &res) override; - void set_local_kvs_id(size_t local_kvs_id) override; + atl_status_t set_local_kvs_id(size_t local_kvs_id) override; size_t get_threads_per_process() override { return 1; @@ -63,6 +63,8 @@ class pmi_simple final : public ipmi { return 1; } + atl_status_t pmrt_init() override; + private: int rank; int size; diff --git a/src/ccl_api_functions.cpp b/src/ccl_api_functions.cpp index 4ae113e2b..1a48173cd 100644 --- a/src/ccl_api_functions.cpp +++ b/src/ccl_api_functions.cpp @@ -16,12 +16,12 @@ #include "oneapi/ccl/types.hpp" #include "oneapi/ccl/environment.hpp" #include "oneapi/ccl/api_functions.hpp" -#include "common/comm/host_communicator/host_communicator.hpp" +#include "common/comm/comm.hpp" #include "oneapi/ccl/exception.hpp" -#if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) +#if defined(CCL_ENABLE_ZE) || defined(CCL_ENABLE_SYCL) #include "common/comm/comm_interface.hpp" -#endif //#if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) +#endif //#if defined(CCL_ENABLE_ZE) || defined(CCL_ENABLE_SYCL) #include 
"ccl_api_functions_generators.hpp" #include "common/global/global.hpp" diff --git a/src/ccl_cpp_communicator.cpp b/src/ccl_cpp_communicator.cpp index 7e608b59b..af93c5dac 100644 --- a/src/ccl_cpp_communicator.cpp +++ b/src/ccl_cpp_communicator.cpp @@ -46,14 +46,12 @@ #include "oneapi/ccl/event.hpp" #include "oneapi/ccl/communicator.hpp" -#include "common/comm/l0/comm_context_storage.hpp" #include "common/global/global.hpp" //TODO #include "common/comm/comm.hpp" -#include "common/comm/l0/comm_context.hpp" #include "communicator_impl.hpp" namespace ccl { diff --git a/src/ccl_cpp_environment.cpp b/src/ccl_cpp_environment.cpp index 7c7297170..37fb94813 100644 --- a/src/ccl_cpp_environment.cpp +++ b/src/ccl_cpp_environment.cpp @@ -18,10 +18,9 @@ #include "exec/exec.hpp" #include "common/utils/version.hpp" -#if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) -#include "common/comm/l0/comm_context.hpp" +#if defined(CCL_ENABLE_ZE) || defined(CCL_ENABLE_SYCL) #include "common/comm/comm_interface.hpp" -#endif //#if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) +#endif //#if defined(CCL_ENABLE_ZE) || defined(CCL_ENABLE_SYCL) #include diff --git a/src/ccl_cpp_kvs.cpp b/src/ccl_cpp_kvs.cpp index 600e523c3..765f9f36b 100644 --- a/src/ccl_cpp_kvs.cpp +++ b/src/ccl_cpp_kvs.cpp @@ -53,8 +53,10 @@ kvs::address_type kvs_impl::get_addr() { } vector_class kvs_impl::get(const string_class& key) { - char ret[128]; - inter_kvs->kvs_get_value_by_name_key(prefix.c_str(), key.c_str(), ret); + char ret[MAX_KVS_VAL_LENGTH]; + CCL_THROW_IF_NOT(inter_kvs->kvs_get_value_by_name_key(prefix.c_str(), key.c_str(), ret) == + KVS_STATUS_SUCCESS, + "kvs get failed"); size_t ret_len = strlen(ret); vector_class ret_vec; if (ret_len != 0) { diff --git a/src/coll/algorithms/algorithm_utils.cpp b/src/coll/algorithms/algorithm_utils.cpp index 48b5c00cf..04ecd4525 100644 --- a/src/coll/algorithms/algorithm_utils.cpp +++ b/src/coll/algorithms/algorithm_utils.cpp @@ -13,16 +13,12 @@ See the 
License for the specific language governing permissions and limitations under the License. */ -#include "coll/algorithms/algorithms_enum.hpp" +#include +#include +#include -bool ccl_coll_type_is_reduction(ccl_coll_type ctype) { - switch (ctype) { - case ccl_coll_allreduce: - case ccl_coll_reduce: - case ccl_coll_reduce_scatter: return true; - default: return false; - } -} +#include "coll/algorithms/algorithm_utils.hpp" +#include "common/log/log.hpp" const char* ccl_coll_type_to_str(ccl_coll_type type) { switch (type) { @@ -35,9 +31,56 @@ const char* ccl_coll_type_to_str(ccl_coll_type type) { case ccl_coll_reduce: return "reduce"; case ccl_coll_reduce_scatter: return "reduce_scatter"; case ccl_coll_sparse_allreduce: return "sparse_allreduce"; - case ccl_coll_internal: return "internal"; case ccl_coll_partial: return "partial"; + case ccl_coll_undefined: return "undefined"; default: return "unknown"; } return "unknown"; } + +void ccl_get_segment_sizes(size_t dtype_size, + size_t elem_count, + size_t requested_seg_size, + std::vector& seg_sizes) { + seg_sizes.clear(); + + if (dtype_size * elem_count == 0) { + return; + } + else if (dtype_size >= requested_seg_size) { + seg_sizes.resize(elem_count, 1); + } + else { + size_t seg_size = (requested_seg_size + dtype_size - 1) / dtype_size; + size_t total_seg_count = std::max((elem_count + seg_size - 1) / seg_size, 1UL); + size_t regular_seg_size = elem_count / total_seg_count; + size_t large_seg_size = regular_seg_size + ((elem_count % total_seg_count) != 0); + size_t regular_seg_count = total_seg_count * large_seg_size - elem_count; + + seg_sizes.resize(total_seg_count, regular_seg_size); + std::fill(seg_sizes.begin() + regular_seg_count, seg_sizes.end(), large_seg_size); + + size_t sum = std::accumulate(seg_sizes.begin(), seg_sizes.end(), 0); + if (sum != elem_count) { + std::stringstream ss; + for (size_t idx = 0; idx < seg_sizes.size(); idx++) { + ss << seg_sizes[idx] << " "; + } + CCL_THROW_IF_NOT(false, + "unexpected 
sum of seg_sizes ", + sum, + ", expected ", + elem_count, + ", total_seg_count ", + total_seg_count, + ", regular_seg_count ", + regular_seg_count, + ", regular_seg_size ", + regular_seg_size, + ", large_seg_size ", + large_seg_size, + ", all seg_sizes: ", + ss.str()); + } + } +} diff --git a/src/coll/algorithms/algorithms_enum.hpp b/src/coll/algorithms/algorithm_utils.hpp similarity index 64% rename from src/coll/algorithms/algorithms_enum.hpp rename to src/coll/algorithms/algorithm_utils.hpp index 880991e44..f90b43b2d 100644 --- a/src/coll/algorithms/algorithms_enum.hpp +++ b/src/coll/algorithms/algorithm_utils.hpp @@ -14,8 +14,10 @@ limitations under the License. */ #pragma once -#include "common/utils/enums.hpp" +#include + +#include "common/utils/enums.hpp" #include "oneapi/ccl/types.hpp" #define CCL_COLL_LIST \ @@ -30,7 +32,8 @@ enum ccl_coll_allgatherv_algo { ccl_coll_allgatherv_naive, ccl_coll_allgatherv_ring, ccl_coll_allgatherv_flat, - ccl_coll_allgatherv_multi_bcast + ccl_coll_allgatherv_multi_bcast, + ccl_coll_allgatherv_topo }; enum ccl_coll_allreduce_algo { @@ -38,13 +41,13 @@ enum ccl_coll_allreduce_algo { ccl_coll_allreduce_direct, ccl_coll_allreduce_rabenseifner, - ccl_coll_allreduce_starlike, + ccl_coll_allreduce_nreduce, ccl_coll_allreduce_ring, ccl_coll_allreduce_ring_rma, ccl_coll_allreduce_double_tree, ccl_coll_allreduce_recursive_doubling, ccl_coll_allreduce_2d, - ccl_coll_allreduce_topo_ring + ccl_coll_allreduce_topo }; enum ccl_coll_alltoall_algo { @@ -79,7 +82,7 @@ enum ccl_coll_bcast_algo { ccl_coll_bcast_ring, ccl_coll_bcast_double_tree, ccl_coll_bcast_naive, - ccl_coll_bcast_topo_ring + ccl_coll_bcast_topo }; enum ccl_coll_reduce_algo { @@ -89,14 +92,15 @@ enum ccl_coll_reduce_algo { ccl_coll_reduce_rabenseifner, ccl_coll_reduce_tree, ccl_coll_reduce_double_tree, - ccl_coll_reduce_topo_ring + ccl_coll_reduce_topo }; enum ccl_coll_reduce_scatter_algo { ccl_coll_reduce_scatter_undefined = 0, ccl_coll_reduce_scatter_direct, - 
ccl_coll_reduce_scatter_ring + ccl_coll_reduce_scatter_ring, + ccl_coll_reduce_scatter_topo }; enum ccl_coll_sparse_allreduce_algo { @@ -136,54 +140,15 @@ enum ccl_coll_type { ccl_coll_sparse_allreduce, ccl_coll_last_regular = ccl_coll_sparse_allreduce, - ccl_coll_internal, ccl_coll_partial, + ccl_coll_undefined, ccl_coll_last_value }; -// Currently ccl_coll_type is used in both compile-time and run-time contexts, so -// need to have both versions of the check. -// It's possible to have a constexpr function, but it requires some features from c++14 -// (e.g. multiple returns in constexpr functions) - -template -struct is_reduction_coll_type : std::false_type {}; - -// Reduction types -template -struct is_reduction_coll_type< - ctype, - typename std::enable_if::type> : std::true_type {}; - -bool ccl_coll_type_is_reduction(ccl_coll_type ctype); const char* ccl_coll_type_to_str(ccl_coll_type type); -#define CCL_COLL_TYPE_LIST \ - ccl_coll_type::ccl_coll_allgatherv, ccl_coll_type::ccl_coll_allreduce, \ - ccl_coll_type::ccl_coll_alltoall, ccl_coll_type::ccl_coll_alltoallv, \ - ccl_coll_type::ccl_coll_barrier, ccl_coll_type::ccl_coll_bcast, \ - ccl_coll_type::ccl_coll_reduce, ccl_coll_type::ccl_coll_reduce_scatter, \ - ccl_coll_type::ccl_coll_sparse_allreduce - -enum ccl_coll_reduction { - sum, - prod, - min, - max, - //custom, TODO: make support of custom reduction in *.cl - - last_value -}; - -#define REDUCE_TYPES \ - ccl::reduction::sum, ccl::reduction::prod, ccl::reduction::min, \ - ccl::reduction::max /*, ccl::reduction::custom*/ - -using ccl_reductions = - utils::enum_to_str::type>( - ccl::reduction::custom)>; -inline const std::string reduction_to_str(ccl::reduction reduction_type) { - return ccl_reductions({ "sum", "prod", "min", "max" }).choose(reduction_type, "INVALID_VALUE"); -} +void ccl_get_segment_sizes(size_t dtype_size, + size_t elem_count, + size_t requested_seg_size, + std::vector& seg_sizes); diff --git a/src/coll/algorithms/algorithms.hpp 
b/src/coll/algorithms/algorithms.hpp index 712de99f8..0fdce349c 100644 --- a/src/coll/algorithms/algorithms.hpp +++ b/src/coll/algorithms/algorithms.hpp @@ -38,14 +38,14 @@ ccl::status ccl_coll_build_scatter_ring_allgather_bcast(ccl_sched* sched, int root, ccl_comm* comm); -#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) +#if defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE) ccl::status ccl_coll_build_gpu_bcast(ccl_sched* sched, ccl_buffer buf, size_t count, const ccl_datatype& dtype, int root, ccl_comm* comm); -#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE ccl::status ccl_coll_build_dissemination_barrier(ccl_sched* sched, ccl_comm* comm); @@ -58,7 +58,7 @@ ccl::status ccl_coll_build_rabenseifner_reduce(ccl_sched* sched, int root, ccl_comm* comm); -#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) +#if defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE) ccl::status ccl_coll_build_gpu_reduce(ccl_sched* sched, ccl_buffer send_buf, ccl_buffer recv_buf, @@ -67,7 +67,7 @@ ccl::status ccl_coll_build_gpu_reduce(ccl_sched* sched, ccl::reduction reduction, int root, ccl_comm* comm); -#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE ccl::status ccl_coll_build_rabenseifner_allreduce(ccl_sched* sched, ccl_buffer send_buf, @@ -110,23 +110,24 @@ ccl::status ccl_coll_build_recursive_doubling_allreduce(ccl_sched* sched, ccl::reduction reduction, ccl_comm* comm); -ccl::status ccl_coll_build_starlike_allreduce(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction reduction, - ccl_comm* comm); +ccl::status ccl_coll_build_nreduce_allreduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction reduction, + ccl_comm* comm); -#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) -ccl::status ccl_coll_build_gpu_allreduce(ccl_sched* 
sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction reduction, - ccl_comm* comm); -#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT +#if defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE) +ccl::status ccl_coll_build_topo_allreduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction reduction, + ccl_comm* comm); + +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE ccl::status ccl_coll_build_naive_allgatherv(ccl_sched* sched, ccl_buffer send_buf, @@ -225,6 +226,14 @@ ccl::status ccl_coll_build_multi_bcast_allgatherv(ccl_master_sched* main_sched, const ccl_coll_param& coll_param, size_t data_partition_count); +ccl::status ccl_coll_build_topo_allgatherv(ccl_sched* sched, + ccl_buffer send_buf, + size_t send_count, + ccl_buffer recv_buf, + const size_t* recv_counts, + const ccl_datatype& dtype, + ccl_comm* comm); + ccl::status ccl_coll_build_naive_alltoallv(ccl_master_sched* main_sched, std::vector& scheds, const ccl_coll_param& coll_param); @@ -295,3 +304,11 @@ ccl::status ccl_coll_build_direct_reduce_scatter(ccl_sched* sched, const ccl_datatype& dtype, ccl::reduction reduction, ccl_comm* comm); + +ccl::status ccl_coll_build_topo_reduce_scatter(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t send_count, + const ccl_datatype& dtype, + ccl::reduction reduction, + ccl_comm* comm); diff --git a/src/coll/algorithms/allgatherv.cpp b/src/coll/algorithms/allgatherv.cpp index 98d1bd160..e43a7474b 100644 --- a/src/coll/algorithms/allgatherv.cpp +++ b/src/coll/algorithms/allgatherv.cpp @@ -13,12 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#include - #include "coll/algorithms/algorithms.hpp" +#include "common/comm/comm.hpp" #include "sched/entry/coll/coll_entry_helper.hpp" #include "sched/entry/factory/chunked_entry_factory.hpp" #include "sched/entry/factory/entry_factory.hpp" +#if defined(CCL_ENABLE_ZE) && defined(CCL_ENABLE_SYCL) +#include "coll/coll_util.hpp" +#endif // CCL_ENABLE_ZE && CCL_ENABLE_SYCL + +#include ccl::status ccl_coll_build_direct_allgatherv(ccl_sched* sched, ccl_buffer send_buf, @@ -29,7 +33,7 @@ ccl::status ccl_coll_build_direct_allgatherv(ccl_sched* sched, ccl_comm* comm) { LOG_DEBUG("build direct allgatherv"); - entry_factory::make_entry( + entry_factory::create( sched, send_buf, send_count, recv_buf, recv_counts, dtype, comm); return ccl::status::success; } @@ -43,35 +47,37 @@ ccl::status ccl_coll_build_naive_allgatherv(ccl_sched* sched, ccl_comm* comm) { LOG_DEBUG("build naive allgatherv"); + ccl::status status = ccl::status::success; + int comm_size = comm->size(); - int this_rank = comm->rank(); + int comm_rank = comm->rank(); size_t dtype_size = dtype.size(); - size_t* offsets = static_cast(CCL_MALLOC(comm_size * sizeof(size_t), "offsets")); - ccl::status status = ccl::status::success; + std::vector offsets(comm_size); offsets[0] = 0; - for (int rank_idx = 1; rank_idx < comm_size; ++rank_idx) { - offsets[rank_idx] = offsets[rank_idx - 1] + recv_counts[rank_idx - 1] * dtype_size; + for (int rank = 1; rank < comm_size; rank++) { + offsets[rank] = offsets[rank - 1] + recv_counts[rank - 1] * dtype_size; } if (send_buf != recv_buf) { // out-of-place case - entry_factory::make_entry( - sched, send_buf, recv_buf + offsets[this_rank], send_count, dtype); + entry_factory::create( + sched, send_buf, recv_buf + offsets[comm_rank], send_count, dtype); } - for (int rank_idx = 0; rank_idx < comm_size; ++rank_idx) { - if (rank_idx != this_rank) { - // send own buffer to other ranks - entry_factory::make_chunked_send_entry( - sched, recv_buf + offsets[this_rank], send_count, dtype, 
rank_idx, comm); - // recv other's rank buffer - entry_factory::make_chunked_recv_entry( - sched, recv_buf + offsets[rank_idx], recv_counts[rank_idx], dtype, rank_idx, comm); - } + for (int idx = 1; idx < comm_size; idx++) { + int dst = (comm_rank + idx) % comm_size; + int src = (comm_rank - idx + comm_size) % comm_size; + + // send own buffer to other ranks + entry_factory::create( + sched, recv_buf + offsets[comm_rank], send_count, dtype, dst, comm); + + // recv other's rank buffer + entry_factory::create( + sched, recv_buf + offsets[src], recv_counts[src], dtype, src, comm); } - CCL_FREE(offsets); return status; } @@ -99,7 +105,7 @@ ccl::status ccl_coll_build_ring_allgatherv(ccl_sched* sched, } if (send_buf != recv_buf) { - entry_factory::make_entry( + entry_factory::create( sched, send_buf, recv_buf + offsets[rank], send_count, dtype); } @@ -124,8 +130,8 @@ ccl::status ccl_coll_build_ring_allgatherv(ccl_sched* sched, sbuf = recv_buf + send_block_offset; rbuf = recv_buf + recv_block_offset; - entry_factory::make_entry(sched, sbuf, send_block_count, dtype, dst, comm); - entry_factory::make_entry(sched, rbuf, recv_block_count, dtype, src, comm); + entry_factory::create(sched, sbuf, send_block_count, dtype, dst, comm); + entry_factory::create(sched, rbuf, recv_block_count, dtype, src, comm); sched->add_barrier(); block_idx = (comm_size + block_idx - 1) % comm_size; // move left @@ -197,13 +203,13 @@ ccl::status ccl_coll_build_flat_allgatherv(ccl_master_sched* main_sched, ccl_buffer_type::INDIRECT); if (!inplace) { - entry_factory::make_entry(scheds[2 * comm_rank % sched_count], - ccl_buffer(coll_param.get_send_buf_ptr(), - coll_param.get_send_count() * dtype_size, - ccl_buffer_type::INDIRECT), - recv_bufs[comm_rank], - coll_param.get_recv_count(comm_rank), - dtype); + entry_factory::create(scheds[2 * comm_rank % sched_count], + ccl_buffer(coll_param.get_send_buf_ptr(), + coll_param.get_send_count() * dtype_size, + ccl_buffer_type::INDIRECT), + recv_bufs[comm_rank], 
+ coll_param.get_recv_count(comm_rank), + dtype); } else { size_t total_recv_bytes = @@ -225,19 +231,19 @@ ccl::status ccl_coll_build_flat_allgatherv(ccl_master_sched* main_sched, if (static_cast(idx) == comm_rank) continue; - entry_factory::make_entry(scheds[(comm_rank + idx) % sched_count], - recv_bufs[idx], - coll_param.get_recv_count(idx), - dtype, - idx, - comm); - - entry_factory::make_entry(scheds[(comm_rank + idx) % sched_count], - send_seg, - coll_param.get_recv_count(comm_rank), - dtype, - idx, - comm); + entry_factory::create(scheds[(comm_rank + idx) % sched_count], + recv_bufs[idx], + coll_param.get_recv_count(idx), + dtype, + idx, + comm); + + entry_factory::create(scheds[(comm_rank + idx) % sched_count], + send_seg, + coll_param.get_recv_count(comm_rank), + dtype, + idx, + comm); } main_sched->sync_partial_scheds(); @@ -279,15 +285,14 @@ ccl::status ccl_coll_build_multi_bcast_allgatherv(ccl_master_sched* main_sched, CCL_ASSERT(scheds.size() >= data_partition_count); for (size_t idx = 0; idx < data_partition_count; idx++) { - entry_factory::make_entry( - scheds[idx], - ccl_buffer(coll_param.get_send_buf_ptr(), - coll_param.get_send_count() * dtype_size, - copy_offsets[idx], - ccl_buffer_type::INDIRECT), - recv_bufs[comm_rank] + copy_offsets[idx], - copy_counts[idx], - dtype); + entry_factory::create(scheds[idx], + ccl_buffer(coll_param.get_send_buf_ptr(), + coll_param.get_send_count() * dtype_size, + copy_offsets[idx], + ccl_buffer_type::INDIRECT), + recv_bufs[comm_rank] + copy_offsets[idx], + copy_counts[idx], + dtype); } main_sched->sync_partial_scheds(); } @@ -306,3 +311,164 @@ ccl::status ccl_coll_build_multi_bcast_allgatherv(ccl_master_sched* main_sched, return ccl::status::success; } + +#if defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE) + +ccl::status ccl_coll_build_topo_allgatherv(ccl_sched* sched, + ccl_buffer send_buf, + size_t send_count, + ccl_buffer recv_buf, + const size_t* recv_counts, + const ccl_datatype& dtype, + ccl_comm* comm) { 
+ LOG_DEBUG("build topo allgatherv"); + + ccl_comm* pair_comm = comm->get_pair_comm().get(); + ccl_comm* even_comm = comm->get_even_comm().get(); + ccl_comm* node_comm = comm->get_node_comm().get(); + ccl_comm* r2r_comm = comm->get_r2r_comm().get(); + + int comm_size = comm->size(); + int pair_comm_size = pair_comm->size(); + int node_comm_size = node_comm->size(); + int r2r_comm_size = r2r_comm->size(); + + bool is_inplace = send_buf == recv_buf; + bool is_single_node = comm_size == node_comm_size; + + const std::vector in_buffers{ + { send_buf.get_ptr(), ccl::ze::ipc_mem_type::memory }, // 0 + { recv_buf.get_ptr(), ccl::ze::ipc_mem_type::memory }, // 1 + }; + + size_t send_buf_idx = 0; + size_t recv_buf_idx = 1; + + ccl::add_handle_exchange(sched, node_comm, in_buffers); + + if (is_single_node) { + std::vector wait_events; + entry_factory::create(sched, + send_buf, + send_count, + recv_buf, + recv_counts, + dtype, + comm, + wait_events, + recv_buf_idx); + sched->add_barrier(); + ccl::add_comm_barrier(sched, comm); + return ccl::status::success; + } + + // helper function + auto get_distance = [&](int from, int to) { + CCL_THROW_IF_NOT(from >= 0, "from: ", from, " to: ", to); + CCL_THROW_IF_NOT(from <= to, "from: ", from, " to: ", to); + CCL_THROW_IF_NOT(to <= comm_size, "from: ", from, " to: ", to); + return std::accumulate(recv_counts + from, recv_counts + to, 0); + }; + + if (pair_comm->rank() == ccl::global_data::env().kernel_1s_lead) { + /* 1. 
allocate send && recv tmp host buffers for host bcast stage */ + int pair_start = pair_comm->get_global_rank(0, true); + size_t host_send_buf_count = get_distance(pair_start, pair_start + pair_comm_size); + size_t host_send_buf_bytes = host_send_buf_count * dtype.size(); + + size_t host_recv_buf_count{}; // calculate max pair size in recv_count + for (int rank = 0; rank < comm_size; rank += pair_comm_size) { + size_t count = get_distance(rank, rank + pair_comm_size); + host_recv_buf_count = std::max(host_recv_buf_count, count); + } + size_t host_recv_buf_bytes = host_recv_buf_count * dtype.size(); + + LOG_DEBUG("alloc host tmp buffers for bcast: send_buf: ", + host_send_buf_bytes, + ", recv_buf: ", + host_recv_buf_bytes); + ccl::alloc_param host_send_buf_alloc( + host_send_buf_bytes, ccl::buffer_type::regular, ccl::buffer_place::host); + ccl_buffer send_host_buf = sched->alloc_buffer(host_send_buf_alloc); + + ccl::alloc_param host_recv_buf_alloc( + host_recv_buf_bytes, ccl::buffer_type::regular, ccl::buffer_place::host); + ccl_buffer recv_host_buf = sched->alloc_buffer(host_recv_buf_alloc); + + /* 2. copy to host */ + for (int peer_rank = 0, dst_offset{}; peer_rank < pair_comm_size; ++peer_rank) { + int global_rank = pair_comm->get_global_rank(peer_rank, true) - + ccl::global_data::env().kernel_1s_lead; + size_t copy_count = recv_counts[global_rank]; + ccl_buffer src{}; + size_t src_offset = (is_inplace) ? get_distance(0, global_rank) : 0; + copy_attr attr( + peer_rank, send_buf_idx, copy_direction::d2h, pair_comm, src_offset, dst_offset); + if (peer_rank == pair_comm->rank()) { + src = send_buf; + attr = copy_attr(copy_direction::d2h, src_offset); + } + LOG_DEBUG("copy to host: from global rank: ", global_rank, ", count: ", copy_count); + entry_factory::create(sched, src, send_host_buf, copy_count, dtype, attr); + dst_offset += copy_count; + } + sched->add_barrier(); + + /* 3. 
bcast between nodes */ + for (int peer_rank = 0; peer_rank < r2r_comm_size; ++peer_rank) { + ccl_buffer buf = recv_host_buf; + if (peer_rank == r2r_comm->rank()) { + buf = send_host_buf; + } + + int global_rank = r2r_comm->get_global_rank(peer_rank, true); + int r2r_start = global_rank - ccl::global_data::env().kernel_1s_lead; + size_t copy_count = get_distance(r2r_start, r2r_start + pair_comm_size); + LOG_DEBUG("bcast: peer_rank: ", global_rank, ", count ", copy_count); + ccl_coll_build_bcast(sched, buf, copy_count, dtype, peer_rank, r2r_comm); + sched->add_barrier(); + + size_t dst_offset = get_distance(0, r2r_start); + LOG_DEBUG("copy to device: offset: ", dst_offset, ", count: ", copy_count); + entry_factory::create(sched, + buf, + recv_buf, + copy_count, + dtype, + copy_attr(copy_direction::h2d, 0, dst_offset)); + sched->add_barrier(); + } + ccl::add_comm_barrier(sched, even_comm); + + /* 4. allgatherv in even_comm */ + for (int node_idx = 0; node_idx < r2r_comm_size; ++node_idx) { + int from = (comm->rank() - ccl::global_data::env().kernel_1s_lead + + node_idx * node_comm_size) % + comm_size; // TODO: fix lead + int to = from + pair_comm_size; + size_t count = get_distance(from, to); + size_t offset = get_distance(0, from); + for (int i = 0; i < even_comm->size() - 1; ++i) { + int peer_rank = (even_comm->rank() + i + 1) % even_comm->size(); + copy_attr attr( + peer_rank, recv_buf_idx, copy_direction::d2d, even_comm, offset, offset); + entry_factory::create( + sched, recv_buf, ccl_buffer(), count, dtype, attr); + } + } + sched->add_barrier(); + ccl::add_comm_barrier(sched, even_comm); + + /* 5. 
copy to peer pair rank */ + size_t copy_count = get_distance(0, comm_size); + int peer_rank = (pair_comm->rank() + 1) % pair_comm_size; + copy_attr attr(peer_rank, recv_buf_idx, copy_direction::d2d, pair_comm); + entry_factory::create(sched, recv_buf, ccl_buffer(), copy_count, dtype, attr); + sched->add_barrier(); + } + ccl::add_comm_barrier(sched, pair_comm); + + return ccl::status::success; +} + +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE diff --git a/src/coll/algorithms/allreduce/allreduce.cpp b/src/coll/algorithms/allreduce/allreduce.cpp index 89c6d36fb..9b02a8695 100644 --- a/src/coll/algorithms/allreduce/allreduce.cpp +++ b/src/coll/algorithms/allreduce/allreduce.cpp @@ -21,10 +21,15 @@ */ #include "coll/algorithms/algorithms.hpp" -#include "common/comm/host_communicator/host_communicator.hpp" +#include "coll/algorithms/algorithm_utils.hpp" +#include "common/comm/comm.hpp" #include "sched/entry/coll/coll_entry_helper.hpp" +#include "sched/entry/copy/copy_helper.hpp" #include "sched/entry/factory/chunked_entry_factory.hpp" #include "sched/entry/factory/entry_factory.hpp" +#if defined(CCL_ENABLE_ZE) && defined(CCL_ENABLE_SYCL) +#include "coll/coll_util.hpp" +#endif // CCL_ENABLE_ZE && CCL_ENABLE_SYCL ccl::status ccl_coll_build_direct_allreduce(ccl_sched* sched, ccl_buffer send_buf, @@ -35,7 +40,7 @@ ccl::status ccl_coll_build_direct_allreduce(ccl_sched* sched, ccl_comm* comm) { LOG_DEBUG("build direct allreduce"); - entry_factory::make_entry(sched, send_buf, recv_buf, count, dtype, op, comm); + entry_factory::create(sched, send_buf, recv_buf, count, dtype, op, comm); return ccl::status::success; } @@ -57,12 +62,12 @@ ccl::status ccl_coll_build_rabenseifner_allreduce(ccl_sched* sched, comm_size = comm->size(); rank = comm->rank(); - ccl_buffer tmp_buf = sched->alloc_buffer(count * dtype_size); + ccl_buffer tmp_buf = sched->alloc_buffer({ count * dtype_size, send_buf }); /* copy local data into recv_buf */ if (send_buf != recv_buf) { - 
entry_factory::make_entry(sched, send_buf, recv_buf, count, dtype); + entry_factory::create(sched, send_buf, recv_buf, count, dtype); sched->add_barrier(); } @@ -82,7 +87,7 @@ ccl::status ccl_coll_build_rabenseifner_allreduce(ccl_sched* sched, if (rank < 2 * rem) { if (rank % 2 == 0) { /* even */ - entry_factory::make_entry(sched, recv_buf, count, dtype, rank + 1, comm); + entry_factory::create(sched, recv_buf, count, dtype, rank + 1, comm); sched->add_barrier(); /* temporarily set the rank to -1 so that this @@ -91,13 +96,13 @@ ccl::status ccl_coll_build_rabenseifner_allreduce(ccl_sched* sched, newrank = CCL_INVALID_PROC_IDX; } else { /* odd */ - entry_factory::make_entry(sched, tmp_buf, count, dtype, rank - 1, comm); + entry_factory::create(sched, tmp_buf, count, dtype, rank - 1, comm); sched->add_barrier(); /* do the reduction on received data. since the * ordering is right, it doesn't matter whether * the operation is commutative or not. */ - entry_factory::make_entry( + entry_factory::create( sched, tmp_buf, count, recv_buf, nullptr, dtype, op); sched->add_barrier(); @@ -160,27 +165,24 @@ ccl::status ccl_coll_build_rabenseifner_allreduce(ccl_sched* sched, CCL_ASSERT(can_use_recv_reduce); if (can_use_recv_reduce) { - entry_factory::make_entry( - sched, - (recv_buf + disps[recv_idx] * dtype_size), - recv_cnt, - nullptr, - dtype, - op, - dst, - ccl_buffer(), - comm); - entry_factory::make_entry( + entry_factory::create(sched, + (recv_buf + disps[recv_idx] * dtype_size), + recv_cnt, + dtype, + op, + dst, + comm); + entry_factory::create( sched, (recv_buf + disps[send_idx] * dtype_size), send_cnt, dtype, dst, comm); sched->add_barrier(); } else { /* Send data from recv_buf. 
Recv into tmp_buf */ - entry_factory::make_entry( + entry_factory::create( sched, (tmp_buf + disps[recv_idx] * dtype_size), recv_cnt, dtype, dst, comm); /* sendrecv, no barrier here */ - entry_factory::make_entry( + entry_factory::create( sched, (recv_buf + disps[send_idx] * dtype_size), send_cnt, dtype, dst, comm); sched->add_barrier(); @@ -189,14 +191,13 @@ ccl::status ccl_coll_build_rabenseifner_allreduce(ccl_sched* sched, /* This algorithm is used only for predefined ops * and predefined ops are always commutative. */ - entry_factory::make_entry( - sched, - (tmp_buf + disps[recv_idx] * dtype_size), - recv_cnt, - (recv_buf + disps[recv_idx] * dtype_size), - nullptr, - dtype, - op); + entry_factory::create(sched, + (tmp_buf + disps[recv_idx] * dtype_size), + recv_cnt, + (recv_buf + disps[recv_idx] * dtype_size), + nullptr, + dtype, + op); sched->add_barrier(); } @@ -239,10 +240,10 @@ ccl::status ccl_coll_build_rabenseifner_allreduce(ccl_sched* sched, recv_cnt += cnts[i]; } - entry_factory::make_entry( + entry_factory::create( sched, (recv_buf + disps[recv_idx] * dtype_size), recv_cnt, dtype, dst, comm); /* sendrecv, no barrier here */ - entry_factory::make_entry( + entry_factory::create( sched, (recv_buf + disps[send_idx] * dtype_size), send_cnt, dtype, dst, comm); sched->add_barrier(); @@ -258,10 +259,10 @@ ccl::status ccl_coll_build_rabenseifner_allreduce(ccl_sched* sched, * (rank-1), the ranks who didn't participate above. 
*/ if (rank < 2 * rem) { if (rank % 2) { /* odd */ - entry_factory::make_entry(sched, recv_buf, count, dtype, rank - 1, comm); + entry_factory::create(sched, recv_buf, count, dtype, rank - 1, comm); } else { /* even */ - entry_factory::make_entry(sched, recv_buf, count, dtype, rank + 1, comm); + entry_factory::create(sched, recv_buf, count, dtype, rank + 1, comm); } } @@ -290,11 +291,11 @@ ccl::status ccl_coll_build_recursive_doubling_allreduce(ccl_sched* sched, size_t dtype_size = dtype.size(); - ccl_buffer tmp_buf = sched->alloc_buffer(count * dtype_size); + ccl_buffer tmp_buf = sched->alloc_buffer({ count * dtype_size, send_buf }); /* copy local data into recv_buf */ if (send_buf != recv_buf) { - entry_factory::make_entry(sched, send_buf, recv_buf, count, dtype); + entry_factory::create(sched, send_buf, recv_buf, count, dtype); sched->add_barrier(); } @@ -313,7 +314,7 @@ ccl::status ccl_coll_build_recursive_doubling_allreduce(ccl_sched* sched, if (rank < 2 * rem) { if (rank % 2 == 0) { /* even */ - entry_factory::make_entry(sched, recv_buf, count, dtype, rank + 1, comm); + entry_factory::create(sched, recv_buf, count, dtype, rank + 1, comm); sched->add_barrier(); /* temporarily set the rank to -1 so that this @@ -322,14 +323,14 @@ ccl::status ccl_coll_build_recursive_doubling_allreduce(ccl_sched* sched, newrank = -1; } else { /* odd */ - entry_factory::make_entry(sched, tmp_buf, count, dtype, rank - 1, comm); + entry_factory::create(sched, tmp_buf, count, dtype, rank - 1, comm); sched->add_barrier(); /* do the reduction on received data. since the * ordering is right, it doesn't matter whether * the operation is commutative or not. */ - entry_factory::make_entry( + entry_factory::create( sched, tmp_buf, count, recv_buf, nullptr, dtype, op); sched->add_barrier(); @@ -349,14 +350,14 @@ ccl::status ccl_coll_build_recursive_doubling_allreduce(ccl_sched* sched, /* Send the most current data, which is in recv_buf. 
Recv * into tmp_buf */ - entry_factory::make_entry(sched, tmp_buf, count, dtype, dst, comm); + entry_factory::create(sched, tmp_buf, count, dtype, dst, comm); /* sendrecv, no barrier here */ - entry_factory::make_entry(sched, recv_buf, count, dtype, dst, comm); + entry_factory::create(sched, recv_buf, count, dtype, dst, comm); sched->add_barrier(); /* tmp_buf contains data received in this step. * recv_buf contains data accumulated so far */ - entry_factory::make_entry( + entry_factory::create( sched, tmp_buf, count, recv_buf, nullptr, dtype, op); sched->add_barrier(); @@ -369,10 +370,10 @@ ccl::status ccl_coll_build_recursive_doubling_allreduce(ccl_sched* sched, * (rank-1), the ranks who didn't participate above. */ if (rank < 2 * rem) { if (rank % 2) { /* odd */ - entry_factory::make_entry(sched, recv_buf, count, dtype, rank - 1, comm); + entry_factory::create(sched, recv_buf, count, dtype, rank - 1, comm); } else { /* even */ - entry_factory::make_entry(sched, recv_buf, count, dtype, rank + 1, comm); + entry_factory::create(sched, recv_buf, count, dtype, rank + 1, comm); } sched->add_barrier(); } @@ -380,82 +381,145 @@ ccl::status ccl_coll_build_recursive_doubling_allreduce(ccl_sched* sched, return status; } -ccl::status ccl_coll_build_starlike_allreduce(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction op, - ccl_comm* comm) { - LOG_DEBUG("build starlike allreduce"); +ccl::status ccl_coll_build_nreduce_allreduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction op, + ccl_comm* comm) { + LOG_DEBUG("build nreduce allreduce"); ccl::status status = ccl::status::success; int comm_size = comm->size(); - int this_rank = comm->rank(); - size_t* buffer_counts = - static_cast(CCL_MALLOC(comm_size * sizeof(size_t), "buffer_count")); - size_t* buffer_offsets = - static_cast(CCL_MALLOC(comm_size * sizeof(size_t), 
"buffer_offsets")); + int comm_rank = comm->rank(); + std::vector elem_counts(comm_size); + std::vector elem_offsets(comm_size); size_t dtype_size = dtype.size(); + bool is_inplace = (send_buf == recv_buf); - // copy local data into recv_buf - if (send_buf != recv_buf) { - entry_factory::make_entry(sched, send_buf, recv_buf, count, dtype); - sched->add_barrier(); + if (comm_size == 1) { + if (!is_inplace) { + entry_factory::create(sched, send_buf, recv_buf, count, dtype); + } + return status; } - if (comm_size == 1) - return status; + int use_buffering = ccl::global_data::env().allreduce_nreduce_buffering; - // calculate counts and offsets for each rank - size_t common_buffer_count = count / comm_size; - for (int rank_idx = 0; rank_idx < comm_size; ++rank_idx) { - buffer_counts[rank_idx] = common_buffer_count; - buffer_offsets[rank_idx] = rank_idx * buffer_counts[rank_idx] * dtype_size; + size_t segment_size = 2 * 1024 * 1024; + if (ccl::global_data::env().allreduce_nreduce_segment_size != CCL_ENV_SIZET_NOT_SPECIFIED) { + segment_size = ccl::global_data::env().allreduce_nreduce_segment_size; } - buffer_counts[comm_size - 1] += count % comm_size; - - // recv_reduce buffer for current rank - size_t this_rank_buf_size = buffer_counts[this_rank] * dtype_size; - - ccl_buffer tmp_buf; - if (this_rank_buf_size) - tmp_buf = sched->alloc_buffer(this_rank_buf_size * (comm_size - 1)); - - size_t tmp_buf_recv_idx = 0; - for (int rank_idx = 0; rank_idx < comm_size; ++rank_idx) { - if (rank_idx != this_rank) { - // send buffer to others - entry_factory::make_chunked_send_entry(sched, - recv_buf + buffer_offsets[rank_idx], - buffer_counts[rank_idx], - dtype, - rank_idx, - comm); - - // recv part of buffer from others and perform reduce - entry_factory::make_chunked_recv_reduce_entry( - sched, - recv_buf + buffer_offsets[this_rank], - buffer_counts[this_rank], - nullptr, - dtype, - op, - rank_idx, - tmp_buf + this_rank_buf_size * tmp_buf_recv_idx, - comm); - ++tmp_buf_recv_idx; + 
+ std::vector segment_sizes; + ccl_get_segment_sizes(dtype_size, count, segment_size, segment_sizes); + + size_t tmp_buf_size = *segment_sizes.rbegin() * comm_size * dtype_size * 2; + ccl_buffer tmp_buf = sched->alloc_buffer({ tmp_buf_size, send_buf }); + + size_t seg_offset = 0; + + for (size_t seg_idx = 0; seg_idx < segment_sizes.size(); seg_idx++) { + size_t seg_size = segment_sizes[seg_idx]; + + ccl_buffer seg_send_buf = send_buf + seg_offset; + ccl_buffer seg_recv_buf = recv_buf + seg_offset; + ccl_buffer seg_tmp_buf = tmp_buf + (seg_idx % 2) * (tmp_buf_size / 2); + + seg_offset += seg_size * dtype_size; + + // calculate counts and offsets for each rank + size_t common_buffer_count = seg_size / comm_size; + for (int idx = 0; idx < comm_size; idx++) { + elem_counts[idx] = common_buffer_count; + elem_offsets[idx] = idx * elem_counts[idx] * dtype_size; } - } + elem_counts[comm_size - 1] += seg_size % comm_size; - sched->add_barrier(); + size_t elem_count = elem_counts[comm_rank]; + + ccl_buffer reduce_buf; + if (use_buffering) { + reduce_buf = seg_tmp_buf + elem_count * comm_rank * dtype_size; + } + else { + reduce_buf = seg_recv_buf + elem_offsets[comm_rank]; + } - // allgatherv - CCL_CALL(ccl_coll_build_naive_allgatherv( - sched, recv_buf, buffer_counts[this_rank], recv_buf, buffer_counts, dtype, comm)); + if (!is_inplace || use_buffering) { + entry_factory::create(sched, + seg_send_buf + elem_offsets[comm_rank], + reduce_buf, + elem_counts[comm_rank], + dtype); + sched->add_barrier(); + } + + // reduce-scatter + for (int idx = 1; idx < comm_size; idx++) { + int dst = (comm_rank - idx + comm_size) % comm_size; - CCL_FREE(buffer_counts); - CCL_FREE(buffer_offsets); + // send part of buffer to other rank + entry_factory::create( + sched, seg_send_buf + elem_offsets[dst], elem_counts[dst], dtype, dst, comm); + } + + for (int idx = 1; idx < comm_size; idx++) { + int src = (comm_rank + idx) % comm_size; + + // recv part of buffer from other rank and perform reduce + 
entry_factory::create(sched, + reduce_buf, + elem_count, + dtype, + op, + src, + comm, + seg_tmp_buf + elem_count * src * dtype_size); + } + + sched->add_barrier(); + + // allgatherv + if (use_buffering) { + copy_attr attr; + attr.direction = copy_direction::h2h; + attr.use_nontemporal = true; + + // copy own result from tmp to recv buffer + entry_factory::create( + sched, reduce_buf, seg_recv_buf + elem_offsets[comm_rank], elem_count, dtype, attr); + sched->add_barrier(); + + for (int idx = 1; idx < comm_size; idx++) { + int dst = (comm_rank + idx) % comm_size; + int src = (comm_rank - idx + comm_size) % comm_size; + + // send own result to other ranks + entry_factory::create( + sched, reduce_buf, elem_counts[comm_rank], dtype, dst, comm); + + // recv other's rank result into tmp buffer and copy to recv buffer + entry_factory::create(sched, + seg_tmp_buf + elem_offsets[src], + seg_recv_buf + elem_offsets[src], + elem_counts[src] * dtype_size, + src, + comm, + attr); + } + } + else { + CCL_CALL(ccl_coll_build_naive_allgatherv(sched, + seg_recv_buf, + elem_counts[comm_rank], + seg_recv_buf, + elem_counts.data(), + dtype, + comm)); + } + } return status; } @@ -500,163 +564,220 @@ ccl::status ccl_coll_build_ring_allreduce(ccl_sched* sched, return status; } -#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) +#if defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE) -ccl::status ccl_coll_build_gpu_allreduce(ccl_sched* sched, - ccl_buffer send_buf, - ccl_buffer recv_buf, - size_t count, - const ccl_datatype& dtype, - ccl::reduction op, - ccl_comm* comm) { - LOG_DEBUG("build gpu allreduce"); +ccl::status ccl_coll_build_topo_allreduce(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t count, + const ccl_datatype& dtype, + ccl::reduction op, + ccl_comm* comm) { + LOG_DEBUG("build topo allreduce"); - const std::vector in_buffers{ + std::vector in_buffers{ { send_buf.get_ptr(), ccl::ze::ipc_mem_type::memory }, // 0 { recv_buf.get_ptr(), 
ccl::ze::ipc_mem_type::memory }, // 1 }; - ccl_coll_entry_param barrier_param{}; - barrier_param.ctype = ccl_coll_barrier; - barrier_param.comm = comm; - barrier_param.hint_algo.barrier = ccl_coll_barrier_ring; + size_t ipc_event_count{}; + size_t max_ipc_event_count{ 6 }; + ze_event_pool_handle_t ipc_event_pool{}; + if (ccl::global_data::env().enable_ze_barrier) { + ipc_event_pool = sched->get_memory().ipc_event_pool_manager.create(max_ipc_event_count); + in_buffers.push_back({ static_cast(ipc_event_pool), ccl::ze::ipc_mem_type::pool }); + } + + ccl_comm* pair_comm = comm->get_pair_comm().get(); + ccl_comm* even_comm = comm->get_even_comm().get(); + ccl_comm* node_comm = comm->get_node_comm().get(); + ccl_comm* r2r_comm = comm->get_r2r_comm().get(); + + int comm_size = comm->size(); + int even_comm_size = even_comm->size(); + int node_comm_size = node_comm->size(); - ccl_comm* pair_comm = comm->get_host_comm()->get_pair_comm().get()->get_ccl_comm().get(); - ccl_comm* even_comm = comm->get_host_comm()->get_even_comm().get()->get_ccl_comm().get(); - ccl_comm* node_comm = comm->get_host_comm()->get_node_comm().get()->get_ccl_comm().get(); - ccl_comm* r2r_comm = comm->get_host_comm()->get_r2r_comm().get()->get_ccl_comm().get(); + bool is_single_node = (comm_size == node_comm_size); + bool is_single_card = (comm_size == 2) && is_single_node; + bool is_multi_card = (even_comm_size > 1); - int skip_rank = -1; - if (ccl::global_data::env().enable_kernel_1s_ipc_wa) { + size_t recv_buf_idx = 1; + + int skip_rank = ccl_comm::invalid_rank; + if (ccl::global_data::env().enable_kernel_1s_ipc_wa && is_single_card) { skip_rank = ccl::global_data::env().kernel_1s_lead; } - if (sched->coll_attr.to_cache) { - sched->set_entry_exec_mode(ccl_sched_entry_exec_once); - entry_factory::make_entry( - sched, node_comm, in_buffers, skip_rank); - sched->add_barrier(); - sched->set_entry_exec_mode(ccl_sched_entry_exec_regular); + ccl::add_handle_exchange( + sched, node_comm, in_buffers, 
skip_rank, ipc_event_pool, ipc_event_count++); - // TODO: no need barrier for the first iteration where ze_handle_exchange_entry exists - // TODO: think about the right way - coll_entry_helper::add_coll_entry(sched, barrier_param); - } - else { - entry_factory::make_entry( - sched, node_comm, in_buffers, skip_rank); - } + CCL_THROW_IF_NOT(comm_size % 2 == 0, "unexpected comm_size ", comm_size); + CCL_THROW_IF_NOT(node_comm_size % 2 == 0, "unexpected node_comm_size ", node_comm_size); - sched->add_barrier(); + bool use_single_list = sched->enable_ze_single_list(); - if (comm->size() == 4) { - LOG_DEBUG("node_comm: id: ", - node_comm->id(), - ", size: ", - node_comm->size(), - ", rank: ", - node_comm->rank()); - - if (node_comm->size() == 2) { - LOG_DEBUG("r2r_comm: id: ", - r2r_comm->id(), - ", size: ", - r2r_comm->size(), - ", rank: ", - r2r_comm->rank()); - - if (node_comm->rank() == ccl::global_data::env().kernel_1s_lead) { - entry_factory::make_entry( - sched, send_buf, recv_buf, count, dtype, op, node_comm->rank(), node_comm); - sched->add_barrier(); - ccl_buffer host_buf = sched->alloc_buffer(count * dtype.size()); - entry_factory::make_entry( - sched, recv_buf, host_buf, count, dtype, copy_attr(copy_direction::d2h)); - sched->add_barrier(); - ccl_coll_build_allreduce(sched, host_buf, host_buf, count, dtype, op, r2r_comm); - sched->add_barrier(); - entry_factory::make_entry( - sched, host_buf, recv_buf, count, dtype, copy_attr(copy_direction::h2d)); + if (pair_comm->rank() == ccl::global_data::env().kernel_1s_lead) { + std::vector wait_events; + if (is_single_card) { + LOG_DEBUG("topo/scale_up/intra: use ze_onesided_allreduce"); + auto entry = entry_factory::create( + sched, send_buf, recv_buf, count, dtype, op, pair_comm, wait_events); + wait_events.push_back(entry->entry_event); + } + else { + LOG_DEBUG("topo/scale_up/intra: use ze_onesided_reduce"); + auto entry = entry_factory::create(sched, + send_buf, + recv_buf, + count, + dtype, + op, + 
pair_comm->rank(), + pair_comm, + wait_events); + wait_events.push_back(entry->entry_event); + } + sched->add_barrier(); + + size_t main_block_count = count / even_comm_size; + size_t block_count = main_block_count; + if (even_comm->rank() == even_comm_size - 1) { + block_count += count % even_comm_size; + } + + if (is_multi_card) { + auto barrier_event = ccl::add_comm_barrier( + sched, even_comm, wait_events, ipc_event_pool, ipc_event_count++); + wait_events.push_back(barrier_event); + + if (is_single_node) { + LOG_DEBUG("topo/scale_up/inter: use ze_a2a_allreduce"); + auto entry = entry_factory::create(sched, + recv_buf, + recv_buf, + count, + dtype, + op, + even_comm, + wait_events, + recv_buf_idx); + wait_events.push_back(entry->entry_event); sched->add_barrier(); - entry_factory::make_entry( - sched, - recv_buf, - ccl_buffer(), - count, - dtype, - copy_attr((node_comm->rank() + 1) % node_comm->size(), 1, copy_direction::d2d)); + + auto barrier_event = ccl::add_comm_barrier( + sched, even_comm, wait_events, ipc_event_pool, ipc_event_count++); + wait_events.push_back(barrier_event); + } + else { + size_t offset_bytes = main_block_count * even_comm->rank() * dtype.size(); + ccl_buffer partial_recv_buf = recv_buf + offset_bytes; + LOG_DEBUG("topo/scale_up/inter: use ze_a2a_reduce_scatter_entry"); + std::vector block_counts(even_comm->size(), main_block_count); + block_counts.back() = block_count; + auto entry = entry_factory::create(sched, + recv_buf, + partial_recv_buf, + block_counts.data(), + dtype, + op, + even_comm, + wait_events, + recv_buf_idx); + wait_events.push_back(entry->entry_event); sched->add_barrier(); + + auto barrier_event = ccl::add_comm_barrier( + sched, even_comm, wait_events, ipc_event_pool, ipc_event_count++); + wait_events.push_back(barrier_event); } - barrier_param.comm = comm; - coll_entry_helper::add_coll_entry(sched, barrier_param); } - else if (node_comm->size() == 4) { - LOG_DEBUG("pair_comm: id: ", - pair_comm->id(), - ", size: ", - 
pair_comm->size(), - ", rank: ", - pair_comm->rank()); - - LOG_DEBUG("even_comm: id: ", - even_comm->id(), - ", size: ", - even_comm->size(), - ", rank: ", - even_comm->rank()); - - if (pair_comm->rank() == ccl::global_data::env().kernel_1s_lead) { - entry_factory::make_entry( - sched, send_buf, recv_buf, count, dtype, op, pair_comm->rank(), pair_comm); - sched->add_barrier(); - barrier_param.comm = even_comm; - coll_entry_helper::add_coll_entry(sched, barrier_param); - sched->add_barrier(); + if (!is_single_node && block_count) { + LOG_DEBUG("topo/scale_out: use host_allreduce"); + ccl::alloc_param alloc_param( + block_count * dtype.size(), ccl::buffer_type::regular, ccl::buffer_place::host); + ccl_buffer host_buf = sched->alloc_buffer(alloc_param); + size_t offset_bytes = main_block_count * even_comm->rank() * dtype.size(); + ccl_buffer partial_recv_buf = recv_buf + offset_bytes; + auto entry = entry_factory::create(sched, + partial_recv_buf, + host_buf, + block_count, + dtype, + copy_attr(copy_direction::d2h), + wait_events); + wait_events.push_back(entry->entry_event); + sched->add_barrier(); - if (even_comm->rank() == ccl::global_data::env().kernel_1s_lead) { - entry_factory::make_entry( - sched, recv_buf, recv_buf, count, dtype, op, even_comm); - sched->add_barrier(); - } + if (use_single_list) { + ccl::add_wait_events(sched, wait_events); } - barrier_param.comm = comm; - coll_entry_helper::add_coll_entry(sched, barrier_param); + ccl_coll_build_allreduce(sched, host_buf, host_buf, block_count, dtype, op, r2r_comm); sched->add_barrier(); - if (pair_comm->rank() != ccl::global_data::env().kernel_1s_lead) { - entry_factory::make_entry( - sched, - ccl_buffer(), - recv_buf, - count, - dtype, - copy_attr((pair_comm->rank() + 1) % pair_comm->size(), - 1, - copy_direction::d2d, - pair_comm)); - sched->add_barrier(); + if (use_single_list) { + auto signal_event = ccl::add_signal_event(sched); + wait_events.push_back(signal_event); } + + entry = 
entry_factory::create(sched, + host_buf, + partial_recv_buf, + block_count, + dtype, + copy_attr(copy_direction::h2d), + wait_events); + wait_events.push_back(entry->entry_event); + sched->add_barrier(); } - else { - CCL_THROW("unexpected node_comm size: ", node_comm->size()); + + if (is_multi_card && !is_single_node) { + LOG_DEBUG("topo/scale_up/inter: use ze_a2a_allgatherv"); + std::vector recv_counts(even_comm_size, main_block_count); + recv_counts.at(even_comm->rank()) = block_count; + auto entry = entry_factory::create(sched, + recv_buf, + block_count, + recv_buf, + recv_counts.data(), + dtype, + even_comm, + wait_events, + recv_buf_idx); + wait_events.push_back(entry->entry_event); + sched->add_barrier(); + + auto barrier_event = ccl::add_comm_barrier( + sched, even_comm, wait_events, ipc_event_pool, ipc_event_count++); + wait_events.push_back(barrier_event); } - } - else if (comm->size() == 2) { - if (comm->rank() == ccl::global_data::env().kernel_1s_lead) { - entry_factory::make_entry( - sched, send_buf, recv_buf, count, dtype, op, comm); + + if (!is_single_card) { + LOG_DEBUG("topo/scale_up/intra: use ze_onesided_bcast"); + int peer_rank = (pair_comm->rank() + 1) % pair_comm->size(); + auto entry = entry_factory::create( + sched, + recv_buf, + ccl_buffer(), + count, + dtype, + copy_attr(peer_rank, recv_buf_idx, copy_direction::d2d, pair_comm), + wait_events); + wait_events.push_back(entry->entry_event); sched->add_barrier(); } - barrier_param.comm = comm; - coll_entry_helper::add_coll_entry(sched, barrier_param); - } - else { - CCL_THROW("unexpected comm size: ", comm->size()); } + ccl::add_comm_barrier(sched, pair_comm, ipc_event_pool, ipc_event_count++); + + CCL_THROW_IF_NOT(ipc_event_count <= max_ipc_event_count, + "unexpected ipc_event_count ", + ipc_event_count, + ", expected max ", + max_ipc_event_count); + return ccl::status::success; } -#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE diff --git 
a/src/coll/algorithms/allreduce/allreduce_2d.cpp b/src/coll/algorithms/allreduce/allreduce_2d.cpp index 54e5b2719..2457bcee9 100644 --- a/src/coll/algorithms/allreduce/allreduce_2d.cpp +++ b/src/coll/algorithms/allreduce/allreduce_2d.cpp @@ -23,35 +23,30 @@ ccl_allreduce_2d_builder::ccl_allreduce_2d_builder(size_t base_size, ccl_comm* comm) { parent_comm = comm; - size_t vector_size = comm->size(); - std::vector first_dim_colors(vector_size), second_dim_colors(vector_size); + int first_dim_color, second_dim_color; - for (size_t idx = 0; idx < vector_size; idx++) { - if (switch_dims) { - first_dim_colors[idx] = idx / base_size; - second_dim_colors[idx] = idx % base_size; - } - else { - first_dim_colors[idx] = idx % base_size; - second_dim_colors[idx] = idx / base_size; - } + if (switch_dims) { + first_dim_color = comm->rank() / base_size; + second_dim_color = comm->rank() % base_size; + } + else { + first_dim_color = comm->rank() % base_size; + second_dim_color = comm->rank() / base_size; } - first_dim_comm = std::shared_ptr(ccl_comm::create_with_colors( - first_dim_colors, ccl::global_data::get().comm_ids.get(), comm, true /*share_resources*/)); + first_dim_comm = std::shared_ptr(comm->create_with_color( + first_dim_color, ccl::global_data::get().comm_ids.get(), true /*share_resources*/)); - second_dim_comm = std::shared_ptr(ccl_comm::create_with_colors( - second_dim_colors, ccl::global_data::get().comm_ids.get(), comm, true /*share_resources*/)); + second_dim_comm = std::shared_ptr(comm->create_with_color( + second_dim_color, ccl::global_data::get().comm_ids.get(), true /*share_resources*/)); if (comm->rank() == 0) { std::string first_dim_ranks, second_dim_ranks; for (int idx = 0; idx < first_dim_comm->size(); idx++) { - first_dim_ranks += - ((idx) ? " " : "") + std::to_string(first_dim_comm->get_global_rank(idx)); + first_dim_ranks += ((idx) ? 
" " : "") + std::to_string(idx); } for (int idx = 0; idx < second_dim_comm->size(); idx++) { - second_dim_ranks += - ((idx) ? " " : "") + std::to_string(second_dim_comm->get_global_rank(idx)); + second_dim_ranks += ((idx) ? " " : "") + std::to_string(idx); } std::stringstream ss; @@ -79,8 +74,8 @@ static void ccl_allreduce_2d_add_allreduce_allgather(ccl_sched* sched, ccl_comm* comm, size_t chunk_idx, size_t chunk_count) { - ccl_comm* first_dim_comm = comm->allreduce_2d_builder->get_first_dim_comm(); - ccl_comm* second_dim_comm = comm->allreduce_2d_builder->get_second_dim_comm(); + ccl_comm* first_dim_comm = comm->get_allreduce_2d_builder()->get_first_dim_comm(); + ccl_comm* second_dim_comm = comm->get_allreduce_2d_builder()->get_second_dim_comm(); size_t dtype_size = dtype.size(); size_t main_chunk_size = count / chunk_count; @@ -96,7 +91,7 @@ static void ccl_allreduce_2d_add_allreduce_allgather(ccl_sched* sched, if (ar_count) { /* TODO: add second level selection to distinguish high and low level algorithms */ ccl_buffer ar_buf = rbuf + first_dim_comm->rank() * main_block_count * dtype_size; - ccl_coll_build_starlike_allreduce( + ccl_coll_build_nreduce_allreduce( sched, ar_buf, ar_buf, ar_count, dtype, op, second_dim_comm); sched->add_barrier(); } @@ -116,7 +111,7 @@ static void ccl_allreduce_2d_add_reduce_scatter_allreduce_allgather(ccl_sched* s ccl_comm* comm, size_t chunk_idx, size_t chunk_count) { - ccl_comm* first_dim_comm = comm->allreduce_2d_builder->get_first_dim_comm(); + ccl_comm* first_dim_comm = comm->get_allreduce_2d_builder()->get_first_dim_comm(); size_t dtype_size = dtype.size(); size_t main_chunk_size = count / chunk_count; @@ -133,7 +128,7 @@ static void ccl_allreduce_2d_add_reduce_scatter_allreduce_allgather(ccl_sched* s sched, send_buf, recv_buf, count, dtype, op, comm, chunk_idx, chunk_count); } else { - entry_factory::make_entry( + entry_factory::create( sched, chunk_idx, [send_buf, recv_buf, count, &dtype, op, comm, chunk_idx, 
chunk_count](ccl_sched* s) { @@ -142,7 +137,7 @@ static void ccl_allreduce_2d_add_reduce_scatter_allreduce_allgather(ccl_sched* s }, "AR_AG"); - entry_factory::make_entry( + entry_factory::create( sched, chunk_idx + 1, [send_buf, recv_buf, count, &dtype, op, comm, chunk_idx, chunk_count](ccl_sched* s) { diff --git a/src/coll/algorithms/allreduce/allreduce_rma.cpp b/src/coll/algorithms/allreduce/allreduce_rma.cpp index a91357f6d..da418f555 100644 --- a/src/coll/algorithms/allreduce/allreduce_rma.cpp +++ b/src/coll/algorithms/allreduce/allreduce_rma.cpp @@ -146,7 +146,7 @@ ccl::status ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, if (comm_size == 1) { if (!inplace) { - entry_factory::make_entry(sched, send_buf, recv_buf, count, dtype); + entry_factory::create(sched, send_buf, recv_buf, count, dtype); sched->add_barrier(); } return ccl::status::success; @@ -160,25 +160,25 @@ ccl::status ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, sched->set_entry_exec_mode(ccl_sched_entry_exec_once); - entry_factory::make_entry( + entry_factory::create( sched, 2 * comm_size * sizeof(uint64_t), ccl_buffer(ar_handler->sync_flags, 2 * comm_size * sizeof(uint64_t)), &ar_handler->sync_flags_mr, comm); - entry_factory::make_entry( + entry_factory::create( sched, sizeof(uint64_t), ccl_buffer((void*)&ar_handler->sync_flag, sizeof(uint64_t)), &ar_handler->sync_flag_mr, comm); - entry_factory::make_entry( + entry_factory::create( sched, sizeof(uint64_t), ccl_buffer((void*)&ar_handler->dst_ready_flag, sizeof(uint64_t)), &ar_handler->dst_ready_flag_mr, comm); - entry_factory::make_entry( + entry_factory::create( sched, sizeof(uint64_t), ccl_buffer(&ar_handler->dst_ready_value, sizeof(uint64_t)), @@ -187,13 +187,13 @@ ccl::status ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, if (inplace) { tmp_buf = sched->alloc_buffer(count * dtype_size); - entry_factory::make_entry( + entry_factory::create( sched, count * dtype_size, tmp_buf, &ar_handler->tmp_buf_mr, comm); } else - 
entry_factory::make_entry( + entry_factory::create( sched, count * dtype_size, send_buf, &ar_handler->send_buf_mr, comm); - entry_factory::make_entry( + entry_factory::create( sched, count * dtype_size, recv_buf, &ar_handler->recv_buf_mr, comm); sched->set_entry_exec_mode(ccl_sched_entry_exec_regular); @@ -205,24 +205,23 @@ ccl::status ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, ar_handler->src_peer = (comm_size + rank - 1) % comm_size; ar_handler->dst_peer = (comm_size + rank + 1) % comm_size; - entry_factory::make_entry( - sched, rma_ring_allreduce_reset_sync_flag, ar_handler); + entry_factory::create(sched, rma_ring_allreduce_reset_sync_flag, ar_handler); sched->add_barrier(); sched->set_entry_exec_mode(ccl_sched_entry_exec_once); if (inplace) { - send_entry* e = entry_factory::make_entry( - sched, - ccl_buffer(&ar_handler->tmp_buf_mr, sizeof(atl_mr_t)), - sizeof(atl_mr_t), - ccl_datatype_int8, - ar_handler->src_peer, - comm); + send_entry* e = + entry_factory::create(sched, + ccl_buffer(&ar_handler->tmp_buf_mr, sizeof(atl_mr_t)), + sizeof(atl_mr_t), + ccl_datatype_int8, + ar_handler->src_peer, + comm); e->set_field_fn(rma_ring_allreduce_get_tmp_buf_mr, ar_handler); } else { - send_entry* e = entry_factory::make_entry( + send_entry* e = entry_factory::create( sched, ccl_buffer(&ar_handler->recv_buf_mr, sizeof(atl_mr_t)), sizeof(atl_mr_t), @@ -231,39 +230,37 @@ ccl::status ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, comm); e->set_field_fn(rma_ring_allreduce_get_recv_buf_mr, ar_handler); } - send_entry* e = entry_factory::make_entry( - sched, - ccl_buffer(&ar_handler->recv_buf_mr, sizeof(atl_mr_t)), - sizeof(atl_mr_t), - ccl_datatype_int8, - ar_handler->src_peer, - comm); + send_entry* e = + entry_factory::create(sched, + ccl_buffer(&ar_handler->recv_buf_mr, sizeof(atl_mr_t)), + sizeof(atl_mr_t), + ccl_datatype_int8, + ar_handler->src_peer, + comm); e->set_field_fn(rma_ring_allreduce_get_recv_buf_mr, ar_handler); - e = entry_factory::make_entry( - 
sched, - ccl_buffer(&ar_handler->sync_flag_mr, sizeof(atl_mr_t)), - sizeof(atl_mr_t), - ccl_datatype_int8, - ar_handler->src_peer, - comm); + e = entry_factory::create(sched, + ccl_buffer(&ar_handler->sync_flag_mr, sizeof(atl_mr_t)), + sizeof(atl_mr_t), + ccl_datatype_int8, + ar_handler->src_peer, + comm); e->set_field_fn(rma_ring_allreduce_get_sync_flag_mr, ar_handler); - entry_factory::make_entry( + entry_factory::create( sched, ccl_buffer(&ar_handler->remote_rs_dst_buf_mr, sizeof(atl_mr_t)), sizeof(atl_mr_t), ccl_datatype_int8, ar_handler->dst_peer, comm); - entry_factory::make_entry( - sched, - ccl_buffer(&ar_handler->remote_recv_buf_mr, sizeof(atl_mr_t)), - sizeof(atl_mr_t), - ccl_datatype_int8, - ar_handler->dst_peer, - comm); - entry_factory::make_entry( + entry_factory::create(sched, + ccl_buffer(&ar_handler->remote_recv_buf_mr, sizeof(atl_mr_t)), + sizeof(atl_mr_t), + ccl_datatype_int8, + ar_handler->dst_peer, + comm); + entry_factory::create( sched, ccl_buffer(&ar_handler->remote_sync_flag_mr, sizeof(atl_mr_t)), sizeof(atl_mr_t), @@ -272,7 +269,7 @@ ccl::status ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, comm); if (ar_handler->wait_dst) { - send_entry* e = entry_factory::make_entry( + send_entry* e = entry_factory::create( sched, ccl_buffer(ar_handler->dst_ready_flag_mr, sizeof(atl_mr_t)), sizeof(atl_mr_t), @@ -281,7 +278,7 @@ ccl::status ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, comm); e->set_field_fn(rma_ring_allreduce_get_dst_ready_flag_mr, ar_handler); - entry_factory::make_entry( + entry_factory::create( sched, ccl_buffer(&ar_handler->remote_dst_ready_flag_mr, sizeof(atl_mr_t)), sizeof(atl_mr_t), @@ -296,7 +293,7 @@ ccl::status ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, if (ar_handler->wait_dst) { /* let src side to know that this rank (i.e. 
dst for src rank) is ready for write ops */ ar_handler->dst_ready_value = 1; - write_entry* entry = entry_factory::make_entry( + write_entry* entry = entry_factory::create( sched, ccl_buffer(&ar_handler->dst_ready_value, sizeof(uint64_t)), (atl_mr_t*)nullptr, /* src_mr */ @@ -312,11 +309,11 @@ ccl::status ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, rma_ring_allreduce_get_remote_dst_ready_flag_mr, ar_handler); /* wait when dst side will be ready for write ops */ - entry_factory::make_entry( + entry_factory::create( sched, &(ar_handler->dst_ready_flag), 1, ccl_condition_equal); /* reset dst_ready_flag for next allreduce call */ - entry_factory::make_entry( + entry_factory::create( sched, rma_ring_allreduce_reset_dst_ready_flag, ar_handler); } @@ -337,15 +334,15 @@ ccl::status ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, else src_buf = (idx == 0) ? send_buf : recv_buf; - write_entry* entry = entry_factory::make_entry(sched, - src_buf + buf_offset, - (atl_mr_t*)nullptr, /* src_mr */ - block_count, - dtype, - ar_handler->dst_peer, - (atl_mr_t*)nullptr, /* dst_mr */ - buf_offset, - comm); + write_entry* entry = entry_factory::create(sched, + src_buf + buf_offset, + (atl_mr_t*)nullptr, /* src_mr */ + block_count, + dtype, + ar_handler->dst_peer, + (atl_mr_t*)nullptr, /* dst_mr */ + buf_offset, + comm); entry->set_field_fn( (inplace) ? rma_ring_allreduce_get_recv_buf_mr : ((idx == 0) ? 
rma_ring_allreduce_get_send_buf_mr @@ -354,10 +351,10 @@ ccl::status ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, entry->set_field_fn( rma_ring_allreduce_get_remote_rs_dst_buf_mr, ar_handler); - if (block_count * dtype.size() > atl_wrapper::attr.out.max_order_waw_size) + if (block_count * dtype.size() > atl_base_comm::attr.out.max_order_waw_size) sched->add_barrier(); - entry = entry_factory::make_entry( + entry = entry_factory::create( sched, ccl_buffer(&ar_handler->sync_flags[idx], sizeof(uint64_t)), (atl_mr_t*)nullptr, /* src_mr */ @@ -378,18 +375,18 @@ ccl::status ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, block_count += count % comm_size; buf_offset = main_block_count * dtype_size * block_idx; - entry_factory::make_entry( + entry_factory::create( sched, &(ar_handler->sync_flag), (idx + 1), ccl_condition_greater_or_equal); ccl_buffer reduce_in_buf = (inplace) ? tmp_buf : send_buf; ccl_buffer reduce_inout_buf = recv_buf; - entry_factory::make_entry(sched, - reduce_in_buf + buf_offset, - block_count, - reduce_inout_buf + buf_offset, - nullptr, - dtype, - op); + entry_factory::create(sched, + reduce_in_buf + buf_offset, + block_count, + reduce_inout_buf + buf_offset, + nullptr, + dtype, + op); } /* allgather */ @@ -401,24 +398,24 @@ ccl::status ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, buf_offset = main_block_count * dtype_size * block_idx; ccl_buffer src_buf = recv_buf; - write_entry* entry = entry_factory::make_entry(sched, - src_buf + buf_offset, - (atl_mr_t*)nullptr, /* src_mr */ - block_count, - dtype, - ar_handler->dst_peer, - (atl_mr_t*)nullptr, /* dst_mr */ - buf_offset, - comm); + write_entry* entry = entry_factory::create(sched, + src_buf + buf_offset, + (atl_mr_t*)nullptr, /* src_mr */ + block_count, + dtype, + ar_handler->dst_peer, + (atl_mr_t*)nullptr, /* dst_mr */ + buf_offset, + comm); entry->set_field_fn(rma_ring_allreduce_get_recv_buf_mr, ar_handler); entry->set_field_fn(rma_ring_allreduce_get_remote_recv_buf_mr, 
ar_handler); - if (block_count * dtype.size() > atl_wrapper::attr.out.max_order_waw_size) + if (block_count * dtype.size() > atl_base_comm::attr.out.max_order_waw_size) sched->add_barrier(); - entry = entry_factory::make_entry( + entry = entry_factory::create( sched, ccl_buffer(&ar_handler->sync_flags[flag_idx_offset + idx], sizeof(uint64_t)), (atl_mr_t*)nullptr, /* src_mr */ @@ -435,10 +432,10 @@ ccl::status ccl_coll_build_ring_rma_allreduce(ccl_sched* sched, block_idx = (block_idx + comm_size - 1) % comm_size; - entry_factory::make_entry(sched, - &(ar_handler->sync_flag), - (flag_idx_offset + idx + 1), - ccl_condition_greater_or_equal); + entry_factory::create(sched, + &(ar_handler->sync_flag), + (flag_idx_offset + idx + 1), + ccl_condition_greater_or_equal); } return status; diff --git a/src/coll/algorithms/allreduce/allreduce_rma.hpp b/src/coll/algorithms/allreduce/allreduce_rma.hpp index 76e2075c8..80613a4c9 100644 --- a/src/coll/algorithms/allreduce/allreduce_rma.hpp +++ b/src/coll/algorithms/allreduce/allreduce_rma.hpp @@ -15,8 +15,6 @@ */ #pragma once -#include "atl/atl.h" - typedef struct { int wait_dst; diff --git a/src/coll/algorithms/alltoall.cpp b/src/coll/algorithms/alltoall.cpp index 2bd43f5cf..acfa2e4b7 100644 --- a/src/coll/algorithms/alltoall.cpp +++ b/src/coll/algorithms/alltoall.cpp @@ -24,6 +24,6 @@ ccl::status ccl_coll_build_direct_alltoall(ccl_sched* sched, ccl_comm* comm) { LOG_DEBUG("build direct alltoall"); - entry_factory::make_entry(sched, send_buf, recv_buf, count, dtype, comm); + entry_factory::create(sched, send_buf, recv_buf, count, dtype, comm); return ccl::status::success; } diff --git a/src/coll/algorithms/alltoallv.cpp b/src/coll/algorithms/alltoallv.cpp index caf063b4c..32eda791c 100644 --- a/src/coll/algorithms/alltoallv.cpp +++ b/src/coll/algorithms/alltoallv.cpp @@ -35,7 +35,7 @@ ccl::status ccl_coll_build_direct_alltoallv(ccl_sched* sched, ccl_comm* comm) { LOG_DEBUG("build direct alltoallv"); - entry_factory::make_entry( + 
entry_factory::create( sched, send_buf, send_counts, recv_buf, recv_counts, dtype, comm); return ccl::status::success; } @@ -157,17 +157,17 @@ ccl::status ccl_coll_build_naive_alltoallv(ccl_master_sched* main_sched, if (!inplace && send_counts[comm_rank] && recv_counts[comm_rank]) { size_t sched_idx = (2 * comm_rank) % sched_count; - entry_factory::make_entry(scheds[sched_idx], - ccl_buffer(coll_param.get_send_buf_ptr(), - total_send_bytes, - send_offsets[comm_rank], - ccl_buffer_type::INDIRECT), - ccl_buffer(coll_param.get_recv_buf_ptr(), - total_recv_bytes, - recv_offsets[comm_rank], - ccl_buffer_type::INDIRECT), - send_counts[comm_rank], - dtype); + entry_factory::create(scheds[sched_idx], + ccl_buffer(coll_param.get_send_buf_ptr(), + total_send_bytes, + send_offsets[comm_rank], + ccl_buffer_type::INDIRECT), + ccl_buffer(coll_param.get_recv_buf_ptr(), + total_recv_bytes, + recv_offsets[comm_rank], + ccl_buffer_type::INDIRECT), + send_counts[comm_rank], + dtype); } for (int idx = 0; idx < comm_size; idx++) { @@ -179,7 +179,8 @@ ccl::status ccl_coll_build_naive_alltoallv(ccl_master_sched* main_sched, ccl_buffer recv_buf; if (inplace) - recv_buf = scheds[sched_idx]->alloc_buffer(recv_counts[idx] * dtype_size); + recv_buf = scheds[sched_idx]->alloc_buffer( + { recv_counts[idx] * dtype_size, coll_param.get_recv_buf() }); else recv_buf = ccl_buffer(coll_param.get_recv_buf_ptr(), total_recv_bytes, @@ -202,14 +203,14 @@ ccl::status ccl_coll_build_naive_alltoallv(ccl_master_sched* main_sched, if (inplace) { scheds[sched_idx]->add_barrier(); - entry_factory::make_entry(scheds[sched_idx], - recv_buf, - ccl_buffer(coll_param.get_recv_buf_ptr(), - total_recv_bytes, - recv_offsets[idx], - ccl_buffer_type::INDIRECT), - recv_counts[idx], - dtype); + entry_factory::create(scheds[sched_idx], + recv_buf, + ccl_buffer(coll_param.get_recv_buf_ptr(), + total_recv_bytes, + recv_offsets[idx], + ccl_buffer_type::INDIRECT), + recv_counts[idx], + dtype); scheds[sched_idx]->add_barrier(); 
} } @@ -252,17 +253,17 @@ ccl::status ccl_coll_build_scatter_alltoallv(ccl_master_sched* main_sched, if (!inplace && send_counts[comm_rank] && recv_counts[comm_rank]) { size_t sched_idx = (2 * comm_rank) % sched_count; - entry_factory::make_entry(scheds[sched_idx], - ccl_buffer(coll_param.get_send_buf_ptr(), - total_send_bytes, - send_offsets[comm_rank], - ccl_buffer_type::INDIRECT), - ccl_buffer(coll_param.get_recv_buf_ptr(), - total_recv_bytes, - recv_offsets[comm_rank], - ccl_buffer_type::INDIRECT), - send_counts[comm_rank], - dtype); + entry_factory::create(scheds[sched_idx], + ccl_buffer(coll_param.get_send_buf_ptr(), + total_send_bytes, + send_offsets[comm_rank], + ccl_buffer_type::INDIRECT), + ccl_buffer(coll_param.get_recv_buf_ptr(), + total_recv_bytes, + recv_offsets[comm_rank], + ccl_buffer_type::INDIRECT), + send_counts[comm_rank], + dtype); } for (int idx = 0; idx < comm_size; idx++) { @@ -276,7 +277,8 @@ ccl::status ccl_coll_build_scatter_alltoallv(ccl_master_sched* main_sched, ccl_buffer recv_buf; if (inplace) { - recv_buf = scheds[sched_idx]->alloc_buffer(recv_counts[src] * dtype_size); + recv_buf = scheds[sched_idx]->alloc_buffer( + { recv_counts[src] * dtype_size, coll_param.get_recv_buf() }); recv_bufs[src] = recv_buf; } else @@ -321,14 +323,14 @@ ccl::status ccl_coll_build_scatter_alltoallv(ccl_master_sched* main_sched, size_t sched_idx = (comm_rank + idx) % sched_count; - entry_factory::make_entry(scheds[sched_idx], - recv_bufs[idx], - ccl_buffer(coll_param.get_recv_buf_ptr(), - total_recv_bytes, - recv_offsets[idx], - ccl_buffer_type::INDIRECT), - recv_counts[idx], - dtype); + entry_factory::create(scheds[sched_idx], + recv_bufs[idx], + ccl_buffer(coll_param.get_recv_buf_ptr(), + total_recv_bytes, + recv_offsets[idx], + ccl_buffer_type::INDIRECT), + recv_counts[idx], + dtype); } return ccl::status::success; @@ -378,13 +380,13 @@ ccl::status ccl_coll_build_scatter_barrier_alltoallv(ccl_master_sched* main_sche std::vector 
send_scheds(sched_count); for (size_t idx = 0; idx < sched_count; idx++) { - auto recv_sched = entry_factory::make_entry( + auto recv_sched = entry_factory::create( scheds[idx], 0, [](ccl_sched* s) {}, "A2AV_RECV") ->get_subsched(); recv_scheds[idx] = recv_sched; - auto send_sched = entry_factory::make_entry( + auto send_sched = entry_factory::create( scheds[idx], 0, [](ccl_sched* s) {}, "A2AV_SEND") ->get_subsched(); @@ -393,17 +395,17 @@ ccl::status ccl_coll_build_scatter_barrier_alltoallv(ccl_master_sched* main_sche if (!inplace && send_counts[comm_rank] && recv_counts[comm_rank]) { size_t sched_idx = (2 * comm_rank) % sched_count; - entry_factory::make_entry(recv_scheds[sched_idx], - ccl_buffer(coll_param.get_send_buf_ptr(), - total_send_bytes, - send_offsets[comm_rank], - ccl_buffer_type::INDIRECT), - ccl_buffer(coll_param.get_recv_buf_ptr(), - total_recv_bytes, - recv_offsets[comm_rank], - ccl_buffer_type::INDIRECT), - send_counts[comm_rank], - dtype); + entry_factory::create(recv_scheds[sched_idx], + ccl_buffer(coll_param.get_send_buf_ptr(), + total_send_bytes, + send_offsets[comm_rank], + ccl_buffer_type::INDIRECT), + ccl_buffer(coll_param.get_recv_buf_ptr(), + total_recv_bytes, + recv_offsets[comm_rank], + ccl_buffer_type::INDIRECT), + send_counts[comm_rank], + dtype); } for (int idx = 0; idx < comm_size; idx++) { @@ -420,7 +422,8 @@ ccl::status ccl_coll_build_scatter_barrier_alltoallv(ccl_master_sched* main_sche ccl_buffer recv_buf; if (inplace) { - recv_buf = sched->alloc_buffer(recv_counts[src] * dtype_size); + recv_buf = + sched->alloc_buffer({ recv_counts[src] * dtype_size, coll_param.get_recv_buf() }); recv_bufs[src] = recv_buf; } else @@ -465,14 +468,14 @@ ccl::status ccl_coll_build_scatter_barrier_alltoallv(ccl_master_sched* main_sche size_t sched_idx = (comm_rank + idx) % sched_count; - entry_factory::make_entry(scheds[sched_idx], - recv_bufs[idx], - ccl_buffer(coll_param.get_recv_buf_ptr(), - total_recv_bytes, - recv_offsets[idx], - 
ccl_buffer_type::INDIRECT), - recv_counts[idx], - dtype); + entry_factory::create(scheds[sched_idx], + recv_bufs[idx], + ccl_buffer(coll_param.get_recv_buf_ptr(), + total_recv_bytes, + recv_offsets[idx], + ccl_buffer_type::INDIRECT), + recv_counts[idx], + dtype); } return ccl::status::success; diff --git a/src/coll/algorithms/barrier.cpp b/src/coll/algorithms/barrier.cpp index 5aa05e094..d04290a45 100644 --- a/src/coll/algorithms/barrier.cpp +++ b/src/coll/algorithms/barrier.cpp @@ -25,7 +25,7 @@ ccl::status ccl_coll_build_direct_barrier(ccl_sched* sched, ccl_comm* comm) { LOG_DEBUG("build direct barrier"); - entry_factory::make_entry(sched, comm); + entry_factory::create(sched, comm); return ccl::status::success; } @@ -44,8 +44,8 @@ ccl::status ccl_coll_build_dissemination_barrier(ccl_sched* sched, ccl_comm* com while (mask < size) { dst = (rank + mask) % size; src = (rank - mask + size) % size; - entry_factory::make_entry(sched, ccl_buffer(), 0, ccl_datatype_int8, dst, comm); - entry_factory::make_entry(sched, ccl_buffer(), 0, ccl_datatype_int8, src, comm); + entry_factory::create(sched, ccl_buffer(), 0, ccl_datatype_int8, dst, comm); + entry_factory::create(sched, ccl_buffer(), 0, ccl_datatype_int8, src, comm); sched->add_barrier(); mask <<= 1; } diff --git a/src/coll/algorithms/bcast.cpp b/src/coll/algorithms/bcast.cpp index c4ba99976..2183bf71c 100644 --- a/src/coll/algorithms/bcast.cpp +++ b/src/coll/algorithms/bcast.cpp @@ -34,7 +34,7 @@ ccl::status ccl_coll_build_direct_bcast(ccl_sched* sched, ccl_comm* comm) { LOG_DEBUG("build direct bcast"); - entry_factory::make_entry(sched, buf, count, dtype, root, comm); + entry_factory::create(sched, buf, count, dtype, root, comm); return ccl::status::success; } @@ -58,12 +58,12 @@ ccl::status ccl_coll_build_naive_bcast(ccl_sched* sched, if (rank == root) { for (idx = 0; idx < comm_size; idx++) { if (idx != rank) { - entry_factory::make_entry(sched, buf, count, dtype, idx, comm); + entry_factory::create(sched, buf, 
count, dtype, idx, comm); } } } else { - entry_factory::make_entry(sched, buf, count, dtype, root, comm); + entry_factory::create(sched, buf, count, dtype, root, comm); } fn_exit: @@ -116,12 +116,12 @@ ccl::status ccl_coll_build_scatter_for_bcast(ccl_sched* sched, curr_size = recv_size; if (recv_size > 0) { - entry_factory::make_entry(sched, - tmp_buf + relative_rank * scatter_size, - recv_size, - ccl_datatype_int8, - src, - comm); + entry_factory::create(sched, + tmp_buf + relative_rank * scatter_size, + recv_size, + ccl_datatype_int8, + src, + comm); sched->add_barrier(); } break; @@ -145,13 +145,12 @@ ccl::status ccl_coll_build_scatter_for_bcast(ccl_sched* sched, if (dst >= comm_size) dst -= comm_size; - entry_factory::make_entry( - sched, - tmp_buf + scatter_size * (relative_rank + mask), - send_size, - ccl_datatype_int8, - dst, - comm); + entry_factory::create(sched, + tmp_buf + scatter_size * (relative_rank + mask), + send_size, + ccl_datatype_int8, + dst, + comm); sched->add_barrier(); curr_size -= send_size; } @@ -219,10 +218,10 @@ ccl::status ccl_coll_build_scatter_ring_allgather_bcast(ccl_sched* sched, if (right_count < 0) right_count = 0; right_disp = rel_j * scatter_size; - entry_factory::make_entry( + entry_factory::create( sched, tmp_buf + right_disp, right_count, ccl_datatype_int8, right, comm); /* sendrecv, no barrier here */ - entry_factory::make_entry( + entry_factory::create( sched, tmp_buf + left_disp, left_count, ccl_datatype_int8, left, comm); sched->add_barrier(); @@ -234,7 +233,7 @@ ccl::status ccl_coll_build_scatter_ring_allgather_bcast(ccl_sched* sched, return status; } -#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) +#if defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE) ccl::status ccl_coll_build_gpu_bcast(ccl_sched* sched, ccl_buffer buf, @@ -256,20 +255,20 @@ ccl::status ccl_coll_build_gpu_bcast(ccl_sched* sched, if (sched->coll_attr.to_cache) { sched->set_entry_exec_mode(ccl_sched_entry_exec_once); - 
entry_factory::make_entry(sched, comm, buffers); + entry_factory::create(sched, comm, buffers); sched->add_barrier(); sched->set_entry_exec_mode(ccl_sched_entry_exec_regular); coll_entry_helper::add_coll_entry(sched, barrier_param); } else { - entry_factory::make_entry(sched, comm, buffers); + entry_factory::create(sched, comm, buffers); } sched->add_barrier(); if (comm->rank() != root) { - entry_factory::make_entry( + entry_factory::create( sched, ccl_buffer(), buf, count, dtype, copy_attr(root, 0, copy_direction::d2d)); sched->add_barrier(); } @@ -279,4 +278,4 @@ ccl::status ccl_coll_build_gpu_bcast(ccl_sched* sched, return ccl::status::success; } -#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE diff --git a/src/coll/algorithms/double_tree_ops.cpp b/src/coll/algorithms/double_tree_ops.cpp index e124673f2..f9139033b 100644 --- a/src/coll/algorithms/double_tree_ops.cpp +++ b/src/coll/algorithms/double_tree_ops.cpp @@ -25,18 +25,18 @@ static void bcast_tree(const ccl_bin_tree& tree, ccl_comm* comm) { if (tree.parent() != -1) { LOG_DEBUG("recv from parent ", tree.parent()); - entry_factory::make_entry( + entry_factory::create( sched, buffer, count, dtype, static_cast(tree.parent()), comm); sched->add_barrier(); } if (tree.left() != -1) { LOG_DEBUG("send to left ", tree.left()); - entry_factory::make_entry( + entry_factory::create( sched, buffer, count, dtype, static_cast(tree.left()), comm); } if (tree.right() != -1) { LOG_DEBUG("send to right ", tree.right()); - entry_factory::make_entry( + entry_factory::create( sched, buffer, count, dtype, static_cast(tree.right()), comm); } } @@ -50,34 +50,20 @@ static void reduce_tree(const ccl_bin_tree& tree, ccl_comm* comm) { if (tree.left() != -1) { LOG_DEBUG("recv_reduce left ", tree.left()); - entry_factory::make_entry(sched, - buffer, - count, - nullptr, - dtype, - reduction, - static_cast(tree.left()), - ccl_buffer(), - comm); + entry_factory::create( + sched, buffer, count, 
dtype, reduction, static_cast(tree.left()), comm); } if (tree.right() != -1) { LOG_DEBUG("recv_reduce right ", tree.right()); - entry_factory::make_entry(sched, - buffer, - count, - nullptr, - dtype, - reduction, - static_cast(tree.right()), - ccl_buffer(), - comm); + entry_factory::create( + sched, buffer, count, dtype, reduction, static_cast(tree.right()), comm); } if (tree.parent() != -1) { if (tree.left() != -1 || tree.right() != -1) { sched->add_barrier(); } LOG_DEBUG("send to parent ", tree.parent()); - entry_factory::make_entry( + entry_factory::create( sched, buffer, count, dtype, static_cast(tree.parent()), comm); } } @@ -91,27 +77,13 @@ static void reduce_bcast_tree(const ccl_bin_tree& tree, ccl_comm* comm) { if (tree.left() != -1) { LOG_DEBUG("recv_reduce left ", tree.left()); - entry_factory::make_entry(sched, - buffer, - count, - nullptr, - dtype, - reduction, - static_cast(tree.left()), - ccl_buffer(), - comm); + entry_factory::create( + sched, buffer, count, dtype, reduction, static_cast(tree.left()), comm); } if (tree.right() != -1) { LOG_DEBUG("recv_reduce right ", tree.right()); - entry_factory::make_entry(sched, - buffer, - count, - nullptr, - dtype, - reduction, - static_cast(tree.right()), - ccl_buffer(), - comm); + entry_factory::create( + sched, buffer, count, dtype, reduction, static_cast(tree.right()), comm); } if (tree.parent() != -1) { if (tree.left() != -1 || tree.right() != -1) { @@ -119,11 +91,11 @@ static void reduce_bcast_tree(const ccl_bin_tree& tree, } LOG_DEBUG("send to parent ", tree.parent()); - entry_factory::make_entry( + entry_factory::create( sched, buffer, count, dtype, static_cast(tree.parent()), comm); LOG_DEBUG("recv from parent ", tree.parent()); - entry_factory::make_entry( + entry_factory::create( sched, buffer, count, dtype, static_cast(tree.parent()), comm); } @@ -133,12 +105,12 @@ static void reduce_bcast_tree(const ccl_bin_tree& tree, if (tree.left() != -1) { LOG_DEBUG("send to left ", tree.left()); - 
entry_factory::make_entry( + entry_factory::create( sched, buffer, count, dtype, static_cast(tree.left()), comm); } if (tree.right() != -1) { LOG_DEBUG("send to right ", tree.right()); - entry_factory::make_entry( + entry_factory::create( sched, buffer, count, dtype, static_cast(tree.right()), comm); } } @@ -158,7 +130,7 @@ ccl::status ccl_coll_build_double_tree_op(ccl_sched* sched, if (coll_type != ccl_coll_bcast && send_buf != recv_buf) { LOG_DEBUG("out of place op"); - entry_factory::make_entry(sched, send_buf, recv_buf, count, dtype); + entry_factory::create(sched, send_buf, recv_buf, count, dtype); sched->add_barrier(); } @@ -237,7 +209,7 @@ ccl::status ccl_coll_build_double_tree_op(ccl_sched* sched, switch (coll_type) { case ccl_coll_bcast: - entry_factory::make_entry( + entry_factory::create( sched, t1_op_id, [t1_work_buf, t1_work_count, &dtype, t1, comm](ccl_sched* s) { @@ -245,7 +217,7 @@ ccl::status ccl_coll_build_double_tree_op(ccl_sched* sched, }, "bcast_t1"); - entry_factory::make_entry( + entry_factory::create( sched, t2_op_id, [t2_work_buf, t2_work_count, &dtype, t2, comm](ccl_sched* s) { @@ -257,7 +229,7 @@ ccl::status ccl_coll_build_double_tree_op(ccl_sched* sched, case ccl_coll_reduce: { if (comm->rank() % 2 == 0) { //even ranks are leaves in T2, start schedule with T2 - entry_factory::make_entry( + entry_factory::create( sched, t2_op_id, [t2_work_buf, t2_work_count, &dtype, op, t2, comm](ccl_sched* s) { @@ -265,7 +237,7 @@ ccl::status ccl_coll_build_double_tree_op(ccl_sched* sched, }, "reduce_t2"); - entry_factory::make_entry( + entry_factory::create( sched, t1_op_id, [t1_work_buf, t1_work_count, &dtype, op, t1, comm](ccl_sched* s) { @@ -274,7 +246,7 @@ ccl::status ccl_coll_build_double_tree_op(ccl_sched* sched, "reduce_t1"); } else { - entry_factory::make_entry( + entry_factory::create( sched, t2_op_id, [t2_work_buf, t2_work_count, &dtype, op, t2, comm](ccl_sched* s) { @@ -282,7 +254,7 @@ ccl::status ccl_coll_build_double_tree_op(ccl_sched* 
sched, }, "reduce_t2"); - entry_factory::make_entry( + entry_factory::create( sched, t1_op_id, [t1_work_buf, t1_work_count, &dtype, op, t1, comm](ccl_sched* s) { @@ -296,7 +268,7 @@ ccl::status ccl_coll_build_double_tree_op(ccl_sched* sched, case ccl_coll_allreduce: { if (comm->rank() % 2 == 0) { //even ranks are leaves in T2, start schedule with T2 - entry_factory::make_entry( + entry_factory::create( sched, t2_op_id, [t2_work_buf, t2_work_count, &dtype, op, t2, comm](ccl_sched* s) { @@ -304,7 +276,7 @@ ccl::status ccl_coll_build_double_tree_op(ccl_sched* sched, }, "reduce_bcast_t2"); - entry_factory::make_entry( + entry_factory::create( sched, t1_op_id, [t1_work_buf, t1_work_count, &dtype, op, t1, comm](ccl_sched* s) { @@ -313,7 +285,7 @@ ccl::status ccl_coll_build_double_tree_op(ccl_sched* sched, "reduce_bcast_t1"); } else { - entry_factory::make_entry( + entry_factory::create( sched, t1_op_id, [t1_work_buf, t1_work_count, &dtype, op, t1, comm](ccl_sched* s) { @@ -321,7 +293,7 @@ ccl::status ccl_coll_build_double_tree_op(ccl_sched* sched, }, "reduce_bcast_t1"); - entry_factory::make_entry( + entry_factory::create( sched, t2_op_id, [t2_work_buf, t2_work_count, &dtype, op, t2, comm](ccl_sched* s) { diff --git a/src/coll/algorithms/reduce.cpp b/src/coll/algorithms/reduce.cpp index 54a9a55d0..14eb503d9 100644 --- a/src/coll/algorithms/reduce.cpp +++ b/src/coll/algorithms/reduce.cpp @@ -21,8 +21,12 @@ */ #include "coll/algorithms/algorithms.hpp" +#include "common/comm/comm.hpp" #include "sched/entry/coll/coll_entry_helper.hpp" #include "sched/entry/factory/entry_factory.hpp" +#if defined(CCL_ENABLE_ZE) && defined(CCL_ENABLE_SYCL) +#include "coll/coll_util.hpp" +#endif // CCL_ENABLE_ZE && CCL_ENABLE_SYCL /* An implementation of Rabenseifner's reduce algorithm (see http://www.hlrs.de/mpi/myreduce.html). 
@@ -62,7 +66,7 @@ ccl::status ccl_coll_build_direct_reduce(ccl_sched* sched, ccl_comm* comm) { LOG_DEBUG("build direct reduce"); - entry_factory::make_entry( + entry_factory::create( sched, send_buf, recv_buf, count, dtype, reduction, root, comm); return ccl::status::success; } @@ -89,7 +93,7 @@ ccl::status ccl_coll_build_rabenseifner_reduce(ccl_sched* sched, comm_size = comm->size(); rank = comm->rank(); - ccl_buffer tmp_buf = sched->alloc_buffer(count * dtype_size); + ccl_buffer tmp_buf = sched->alloc_buffer({ count * dtype_size, send_buf }); /* get nearest power-of-two less than or equal to comm_size */ pof2 = comm->pof2(); @@ -99,11 +103,11 @@ ccl::status ccl_coll_build_rabenseifner_reduce(ccl_sched* sched, /* If I'm not the root, then my recv_buf may not be valid, therefore * I have to allocate a temporary one */ if (rank != local_root) { - recv_buf = sched->alloc_buffer(count * dtype_size); + recv_buf = sched->alloc_buffer({ count * dtype_size, send_buf }); } if ((rank != local_root) || (send_buf != recv_buf)) - entry_factory::make_entry(sched, send_buf, recv_buf, count, dtype); + entry_factory::create(sched, send_buf, recv_buf, count, dtype); /* In the non-power-of-two case, all odd-numbered * processes of rank < 2*rem send their data to @@ -123,7 +127,7 @@ ccl::status ccl_coll_build_rabenseifner_reduce(ccl_sched* sched, if (rank < 2 * rem) { if (rank % 2 != 0) { /* odd */ - entry_factory::make_entry(sched, recv_buf, count, dtype, rank - 1, comm); + entry_factory::create(sched, recv_buf, count, dtype, rank - 1, comm); sched->add_barrier(); /* temporarily set the rank to -1 so that this @@ -132,13 +136,13 @@ ccl::status ccl_coll_build_rabenseifner_reduce(ccl_sched* sched, new_rank = CCL_INVALID_PROC_IDX; } else { /* even */ - entry_factory::make_entry(sched, tmp_buf, count, dtype, rank + 1, comm); + entry_factory::create(sched, tmp_buf, count, dtype, rank + 1, comm); sched->add_barrier(); /* do the reduction on received data. 
*/ /* This algorithm is used only for predefined ops * and predefined ops are always commutative. */ - entry_factory::make_entry( + entry_factory::create( sched, tmp_buf, count, recv_buf, nullptr, dtype, reduction); sched->add_barrier(); @@ -196,10 +200,10 @@ ccl::status ccl_coll_build_rabenseifner_reduce(ccl_sched* sched, } /* Send data from recv_buf. Recv into tmp_buf */ - entry_factory::make_entry( + entry_factory::create( sched, (recv_buf + disps[send_idx] * dtype_size), send_cnt, dtype, dst, comm); /* sendrecv, no barrier here */ - entry_factory::make_entry( + entry_factory::create( sched, (tmp_buf + disps[recv_idx] * dtype_size), recv_cnt, dtype, dst, comm); sched->add_barrier(); @@ -208,13 +212,13 @@ ccl::status ccl_coll_build_rabenseifner_reduce(ccl_sched* sched, /* This algorithm is used only for predefined ops * and predefined ops are always commutative. */ - entry_factory::make_entry(sched, - (tmp_buf + disps[recv_idx] * dtype_size), - recv_cnt, - (recv_buf + disps[recv_idx] * dtype_size), - nullptr, - dtype, - reduction); + entry_factory::create(sched, + (tmp_buf + disps[recv_idx] * dtype_size), + recv_cnt, + (recv_buf + disps[recv_idx] * dtype_size), + nullptr, + dtype, + reduction); sched->add_barrier(); /* update send_idx for next iteration */ @@ -247,7 +251,7 @@ ccl::status ccl_coll_build_rabenseifner_reduce(ccl_sched* sched, for (i = 1; i < pof2; i++) disps[i] = disps[i - 1] + cnts[i - 1]; - entry_factory::make_entry(sched, recv_buf, cnts[0], dtype, 0, comm); + entry_factory::create(sched, recv_buf, cnts[0], dtype, 0, comm); sched->add_barrier(); new_rank = 0; @@ -255,7 +259,7 @@ ccl::status ccl_coll_build_rabenseifner_reduce(ccl_sched* sched, last_idx = 2; } else if (new_rank == 0) { /* send */ - entry_factory::make_entry( + entry_factory::create( sched, recv_buf, cnts[0], dtype, local_root, comm); sched->add_barrier(); @@ -322,14 +326,14 @@ ccl::status ccl_coll_build_rabenseifner_reduce(ccl_sched* sched, if (newdst_tree_root == newroot_tree_root) 
{ /* send and exit */ /* Send data from recv_buf. Recv into tmp_buf */ - entry_factory::make_entry( + entry_factory::create( sched, (recv_buf + disps[send_idx] * dtype_size), send_cnt, dtype, dst, comm); sched->add_barrier(); break; } else { /* recv and continue */ - entry_factory::make_entry( + entry_factory::create( sched, (recv_buf + disps[recv_idx] * dtype_size), recv_cnt, dtype, dst, comm); sched->add_barrier(); } @@ -372,16 +376,16 @@ ccl::status ccl_coll_build_binomial_reduce(ccl_sched* sched, /* Create a temporary buffer */ size_t dtype_size = dtype.size(); - ccl_buffer tmp_buf = sched->alloc_buffer(count * dtype_size); + ccl_buffer tmp_buf = sched->alloc_buffer({ count * dtype_size, send_buf }); /* If I'm not the root, then my recv_buf may not be valid, therefore * I have to allocate a temporary one */ if (rank != local_root) { - recv_buf = sched->alloc_buffer(count * dtype_size); + recv_buf = sched->alloc_buffer({ count * dtype_size, send_buf }); } if ((rank != local_root) || (send_buf != recv_buf)) { - entry_factory::make_entry(sched, send_buf, recv_buf, count, dtype); + entry_factory::create(sched, send_buf, recv_buf, count, dtype); sched->add_barrier(); } @@ -427,10 +431,10 @@ ccl::status ccl_coll_build_binomial_reduce(ccl_sched* sched, if (source < comm_size) { source = (source + lroot) % comm_size; - entry_factory::make_entry(sched, tmp_buf, count, dtype, source, comm); + entry_factory::create(sched, tmp_buf, count, dtype, source, comm); sched->add_barrier(); - entry_factory::make_entry( + entry_factory::create( sched, tmp_buf, count, recv_buf, nullptr, dtype, reduction); sched->add_barrier(); } @@ -439,7 +443,7 @@ ccl::status ccl_coll_build_binomial_reduce(ccl_sched* sched, /* I've received all that I'm going to. 
Send my result to * my parent */ source = ((relrank & (~mask)) + lroot) % comm_size; - entry_factory::make_entry(sched, recv_buf, count, dtype, source, comm); + entry_factory::create(sched, recv_buf, count, dtype, source, comm); sched->add_barrier(); break; } @@ -449,55 +453,154 @@ ccl::status ccl_coll_build_binomial_reduce(ccl_sched* sched, return status; } -#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) +#if defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE) ccl::status ccl_coll_build_gpu_reduce(ccl_sched* sched, ccl_buffer send_buf, ccl_buffer recv_buf, size_t count, const ccl_datatype& dtype, - ccl::reduction reduction, + ccl::reduction op, int root, ccl_comm* comm) { LOG_DEBUG("build gpu reduce"); - int skip_rank = -1; + ccl_comm* pair_comm = comm->get_pair_comm().get(); + ccl_comm* even_comm = comm->get_even_comm().get(); + ccl_comm* node_comm = comm->get_node_comm().get(); + ccl_comm* r2r_comm = comm->get_r2r_comm().get(); - const std::vector in_buffers{ + int comm_size = comm->size(); + int even_comm_size = even_comm->size(); + int node_comm_size = node_comm->size(); + + bool is_single_node = (comm_size == node_comm_size); + bool is_single_card = (comm_size == 2) && is_single_node; + bool use_tmp_buf = !is_single_card; + + ccl_buffer tmp_buf{}; + ccl::alloc_param alloc_param( + count * dtype.size(), ccl::buffer_type::ze, ccl::buffer_place::device); + if (use_tmp_buf) { + tmp_buf = sched->alloc_buffer(alloc_param); + } + + std::vector in_buffers{ { send_buf.get_ptr(), ccl::ze::ipc_mem_type::memory }, // 0 + { recv_buf.get_ptr(), ccl::ze::ipc_mem_type::memory }, // 1 }; - ccl_coll_entry_param barrier_param{}; - barrier_param.ctype = ccl_coll_barrier; - barrier_param.comm = comm; - barrier_param.hint_algo.barrier = ccl_coll_barrier_ring; + size_t recv_buf_idx = 1; + size_t tmp_buf_idx = std::numeric_limits::max(); + if (use_tmp_buf) { + tmp_buf_idx = in_buffers.size(); + in_buffers.push_back({ tmp_buf.get_ptr(), ccl::ze::ipc_mem_type::memory }); + 
} - if (sched->coll_attr.to_cache) { - sched->set_entry_exec_mode(ccl_sched_entry_exec_once); - entry_factory::make_entry(sched, comm, in_buffers, skip_rank); - sched->add_barrier(); - sched->set_entry_exec_mode(ccl_sched_entry_exec_regular); + ccl::add_handle_exchange(sched, node_comm, in_buffers); - // TODO: no need barrier for the first iteration where ze_handle_exchange_entry exists - // TODO: think about the right way - coll_entry_helper::add_coll_entry(sched, barrier_param); + if (is_single_card) { + LOG_DEBUG("topo/scale_up/intra: use ze_onesided_reduce"); + if (comm->rank() == root) { + entry_factory::create( + sched, send_buf, recv_buf, count, dtype, op, root, pair_comm); + sched->add_barrier(); + } + + ccl::add_comm_barrier(sched, pair_comm); } else { - entry_factory::make_entry(sched, comm, in_buffers, skip_rank); - } + if (pair_comm->rank() == ccl::global_data::env().kernel_1s_lead) { + LOG_DEBUG("topo/scale_up/intra: use ze_onesided_reduce"); + entry_factory::create( + sched, send_buf, tmp_buf, count, dtype, op, pair_comm->rank(), pair_comm); + sched->add_barrier(); - sched->add_barrier(); + size_t main_block_count = count / even_comm_size; + size_t block_count = main_block_count; + if (even_comm->rank() == even_comm_size - 1) { + block_count += count % even_comm_size; + } - if (comm->rank() == root) { - entry_factory::make_entry( - sched, send_buf, recv_buf, count, dtype, reduction, root, comm); - sched->add_barrier(); - } + ccl::add_comm_barrier(sched, even_comm); + size_t offset_bytes = main_block_count * even_comm->rank() * dtype.size(); + ccl_buffer partial_tmp_buf = tmp_buf + offset_bytes; + LOG_DEBUG("topo/scale_up/inter: use ze_a2a_reduce_scatter_entry"); + std::vector wait_events; + std::vector block_counts(even_comm->size(), main_block_count); + block_counts[even_comm->size() - 1] = block_count; + entry_factory::create(sched, + tmp_buf, + partial_tmp_buf, + block_counts.data(), + dtype, + op, + even_comm, + wait_events, + tmp_buf_idx); + 
sched->add_barrier(); + ccl::add_comm_barrier(sched, even_comm); + + CCL_THROW_IF_NOT(comm->size() % node_comm_size == 0); + int root_node_idx = root / node_comm_size; + ccl_buffer host_buf{}; + if (!is_single_node && block_count) { + LOG_DEBUG("topo/scale_out: use host_reduce"); + ccl::alloc_param alloc_param( + block_count * dtype.size(), ccl::buffer_type::regular, ccl::buffer_place::host); + host_buf = sched->alloc_buffer(alloc_param); + entry_factory::create(sched, + partial_tmp_buf, + host_buf, + block_count, + dtype, + copy_attr(copy_direction::d2h)); + sched->add_barrier(); - // TODO: think about the right way - coll_entry_helper::add_coll_entry(sched, barrier_param); + LOG_DEBUG("rank: ", + comm->rank(), + ", reduce to rank on r2r_comm: ", + root_node_idx, + ", count: ", + block_count); + ccl_coll_build_reduce( + sched, host_buf, host_buf, block_count, dtype, op, root_node_idx, r2r_comm); + sched->add_barrier(); + } + + if (root_node_idx == r2r_comm->rank()) { + LOG_DEBUG("topo/scale_up/intra: use ze_onesided_bcast"); + int root_in_node_comm = node_comm->get_rank_from_global(root); + size_t offset_count = offset_bytes / dtype.size(); + ccl_buffer src = (!is_single_node && block_count) ? 
host_buf : partial_tmp_buf; + ccl_buffer dst{}; + copy_attr attr(root_in_node_comm, + recv_buf_idx, + copy_direction::h2d, + node_comm, + 0, + offset_count); + if (comm->rank() == root) { + dst = recv_buf; + attr = copy_attr(copy_direction::h2d, 0, offset_count); + } + + LOG_DEBUG("rank: ", + comm->rank(), + ", copy to rank on node_comm: ", + root_in_node_comm, + ", offset count: ", + offset_count, + ", count: ", + block_count); + entry_factory::create(sched, src, dst, block_count, dtype, attr); + sched->add_barrier(); + } + } + ccl::add_comm_barrier(sched, node_comm); + } return ccl::status::success; } -#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE diff --git a/src/coll/algorithms/reduce_scatter.cpp b/src/coll/algorithms/reduce_scatter.cpp index 4a6774ff7..08e84be57 100644 --- a/src/coll/algorithms/reduce_scatter.cpp +++ b/src/coll/algorithms/reduce_scatter.cpp @@ -21,7 +21,11 @@ */ #include "coll/algorithms/algorithms.hpp" +#include "sched/entry/coll/coll_entry_helper.hpp" #include "sched/entry/factory/entry_factory.hpp" +#if defined(CCL_ENABLE_ZE) && defined(CCL_ENABLE_SYCL) +#include "coll/coll_util.hpp" +#endif // CCL_ENABLE_ZE && CCL_ENABLE_SYCL ccl::status ccl_coll_build_direct_reduce_scatter(ccl_sched* sched, ccl_buffer send_buf, @@ -32,7 +36,7 @@ ccl::status ccl_coll_build_direct_reduce_scatter(ccl_sched* sched, ccl_comm* comm) { LOG_DEBUG("build direct reduce_scatter"); - entry_factory::make_entry( + entry_factory::create( sched, send_buf, recv_buf, recv_count, dtype, reduction, comm); return ccl::status::success; } @@ -70,12 +74,12 @@ ccl::status ccl_coll_build_ring_reduce_scatter_block(ccl_sched* sched, if (!inplace) { /* copy local data into recv_buf */ - entry_factory::make_entry( + entry_factory::create( sched, send_buf + rank * recv_count * dtype_size, recv_buf, recv_count, dtype); } /* allocate temporary buffer to store incoming data */ - ccl_buffer tmp_buf = sched->alloc_buffer(recv_count * dtype_size); 
+ ccl_buffer tmp_buf = sched->alloc_buffer({ recv_count * dtype_size, recv_buf }); for (idx = 1; idx < comm_size; idx++) { src = (comm_size + rank - idx) % comm_size; @@ -84,39 +88,39 @@ ccl::status ccl_coll_build_ring_reduce_scatter_block(ccl_sched* sched, /* send the data that dst needs. recv data that this process * needs from src into tmp_recvbuf */ if (!inplace) { - entry_factory::make_entry( + entry_factory::create( sched, send_buf + dst * recv_count * dtype_size, recv_count, dtype, dst, comm); - entry_factory::make_entry(sched, tmp_buf, recv_count, dtype, src, comm); + entry_factory::create(sched, tmp_buf, recv_count, dtype, src, comm); } else { - entry_factory::make_entry( + entry_factory::create( sched, recv_buf + dst * recv_count * dtype_size, recv_count, dtype, dst, comm); - entry_factory::make_entry(sched, tmp_buf, recv_count, dtype, src, comm); + entry_factory::create(sched, tmp_buf, recv_count, dtype, src, comm); } sched->add_barrier(); if (!inplace) { - entry_factory::make_entry( + entry_factory::create( sched, tmp_buf, recv_count, recv_buf, nullptr, dtype, op); } else { - entry_factory::make_entry(sched, - tmp_buf, - recv_count, - recv_buf + rank * recv_count * dtype_size, - nullptr, - dtype, - op); + entry_factory::create(sched, + tmp_buf, + recv_count, + recv_buf + rank * recv_count * dtype_size, + nullptr, + dtype, + op); } } /* if inplace, move output data to the beginning of * recv_buf. 
already done for rank 0 */ if (inplace && (rank != 0)) { - entry_factory::make_entry( + entry_factory::create( sched, recv_buf + rank * recv_count * dtype_size, recv_buf, recv_count, dtype); } @@ -127,7 +131,7 @@ ccl::status ccl_coll_build_ring_reduce_scatter_block(ccl_sched* sched, ccl::status ccl_coll_build_ring_reduce_scatter(ccl_sched* sched, ccl_buffer send_buf, ccl_buffer recv_buf, - size_t send_count, + size_t recv_count, const ccl_datatype& dtype, ccl::reduction op, ccl_comm* comm) { @@ -151,7 +155,7 @@ ccl::status ccl_coll_build_ring_reduce_scatter(ccl_sched* sched, int src = (comm_size + rank - 1) % comm_size; int dst = (comm_size + rank + 1) % comm_size; - size_t count = send_count; + size_t count = recv_count; size_t bytes = count * dtype_size; size_t chunk_count = @@ -178,7 +182,7 @@ ccl::status ccl_coll_build_ring_reduce_scatter(ccl_sched* sched, if (comm_size == 1) { if (!inplace) { - entry_factory::make_entry(sched, send_buf, recv_buf, count, dtype); + entry_factory::create(sched, send_buf, recv_buf, count, dtype); sched->add_barrier(); } return ccl::status::success; @@ -187,7 +191,7 @@ ccl::status ccl_coll_build_ring_reduce_scatter(ccl_sched* sched, ccl_buffer tmp_buf; if (inplace) { - tmp_buf = sched->alloc_buffer(count * dtype_size); + tmp_buf = sched->alloc_buffer({ count * dtype_size, recv_buf }); } ccl_buffer sbuf, rbuf; @@ -208,7 +212,7 @@ ccl::status ccl_coll_build_ring_reduce_scatter(ccl_sched* sched, size_t send_main_chunk_size, send_last_chunk_size; size_t recv_main_chunk_size, recv_last_chunk_size; - size_t send_chunk_size, recv_chunk_size, reduce_chunk_size; + size_t send_chunk_size, recv_chunk_size = 0, reduce_chunk_size; size_t send_chunk_offset, recv_chunk_offset = 0, reduce_chunk_offset; /* if chunk_count > 1 then make reduction with 1 chunk delay to get comp/comp overlapping */ @@ -278,33 +282,31 @@ ccl::status ccl_coll_build_ring_reduce_scatter(ccl_sched* sched, recv_reduce_local_buf += reduce_chunk_offset; recv_reduce_comm_buf += 
reduce_chunk_offset; - entry_factory::make_entry(sched, sbuf, send_chunk_size, dtype, dst, comm); + entry_factory::create(sched, sbuf, send_chunk_size, dtype, dst, comm); if (!use_prev) { CCL_ASSERT(recv_chunk_size == reduce_chunk_size); - entry_factory::make_entry(sched, - recv_reduce_local_buf, - recv_chunk_size, - nullptr, /* out_cnt */ - dtype, - op, - src, - recv_reduce_comm_buf, - comm, - recv_reduce_result_type); + entry_factory::create(sched, + recv_reduce_local_buf, + recv_chunk_size, + dtype, + op, + src, + comm, + recv_reduce_comm_buf, + recv_reduce_result_type); } else { - entry_factory::make_entry( - sched, rbuf, recv_chunk_size, dtype, src, comm); + entry_factory::create(sched, rbuf, recv_chunk_size, dtype, src, comm); if (idx + chunk_idx > 0) { - entry_factory::make_entry(sched, - reduce_in_buf, - reduce_chunk_size, - reduce_inout_buf, - nullptr, - dtype, - op); + entry_factory::create(sched, + reduce_in_buf, + reduce_chunk_size, + reduce_inout_buf, + nullptr, + dtype, + op); sched->add_barrier(); } @@ -324,13 +326,13 @@ ccl::status ccl_coll_build_ring_reduce_scatter(ccl_sched* sched, reduce_in_buf += recv_chunk_offset; reduce_inout_buf += recv_chunk_offset; - entry_factory::make_entry(sched, - reduce_in_buf, - recv_chunk_size, - reduce_inout_buf, - nullptr, - dtype, - op); + entry_factory::create(sched, + reduce_in_buf, + recv_chunk_size, + reduce_inout_buf, + nullptr, + dtype, + op); } } @@ -343,3 +345,42 @@ ccl::status ccl_coll_build_ring_reduce_scatter(ccl_sched* sched, return status; } + +#if defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE) + +ccl::status ccl_coll_build_topo_reduce_scatter(ccl_sched* sched, + ccl_buffer send_buf, + ccl_buffer recv_buf, + size_t recv_count, + const ccl_datatype& dtype, + ccl::reduction reduction, + ccl_comm* comm) { + LOG_DEBUG("build topo reduce_scatter, recv_count ", recv_count); + + const std::vector in_buffers{ + { send_buf.get_ptr(), ccl::ze::ipc_mem_type::memory }, // 0 + }; + + size_t send_buf_idx = 0; + 
+ ccl::add_handle_exchange(sched, comm, in_buffers); + + std::vector wait_events; + std::vector blocks_count(comm->size(), recv_count); + entry_factory::create(sched, + send_buf, + recv_buf, + blocks_count.data(), + dtype, + reduction, + comm, + wait_events, + send_buf_idx); + sched->add_barrier(); + + ccl::add_comm_barrier(sched, comm); + + return ccl::status::success; +} + +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE diff --git a/src/coll/algorithms/sparse_allreduce/sparse_allreduce.hpp b/src/coll/algorithms/sparse_allreduce/sparse_allreduce.hpp index 32c341a7a..699c01f82 100644 --- a/src/coll/algorithms/sparse_allreduce/sparse_allreduce.hpp +++ b/src/coll/algorithms/sparse_allreduce/sparse_allreduce.hpp @@ -15,6 +15,7 @@ */ #include "oneapi/ccl/type_traits.hpp" #include "coll/algorithms/sparse_allreduce/sparse_handler.hpp" +#include "common/utils/memcpy.hpp" #include "sched/entry/factory/entry_factory.hpp" #define CCL_COALESCE_RESERVE_SIZE 16 @@ -163,7 +164,7 @@ param_nnz.dtype = ccl_datatype_int8; \ param_nnz.comm = comm; \ \ - entry_factory::make_entry(sched, param_nnz); \ + entry_factory::create(sched, param_nnz); \ sched->add_barrier(); \ } while (0) @@ -357,10 +358,8 @@ ccl::status sparse_reduce_ring(const void* ctx) { std::vector buf_v(merge_idx_len * sa_hndl->val_dim_cnt); /* copy what we already have reduced*/ - ccl_comp_copy( - snd_i, buf_i.data(), sa_hndl->itype_size * sa_hndl->dst_count[0], ccl_datatype_int8); - ccl_comp_copy( - snd_v, buf_v.data(), sa_hndl->vtype_size * sa_hndl->dst_count[1], ccl_datatype_int8); + ccl_comp_copy(snd_i, buf_i.data(), sa_hndl->itype_size * sa_hndl->dst_count[0]); + ccl_comp_copy(snd_v, buf_v.data(), sa_hndl->vtype_size * sa_hndl->dst_count[1]); size_t idx_offset = 0; for (auto id : unique_indices_ids) { @@ -389,15 +388,12 @@ ccl::status sparse_reduce_ring(const void* ctx) { new_dst_size)) .get_ptr(); - ccl_comp_copy(buf_i.data(), - (i_type*)(sa_hndl->dst_buf), - sa_hndl->itype_size * merge_idx_len, - ccl_datatype_int8); 
+ ccl_comp_copy( + buf_i.data(), (i_type*)(sa_hndl->dst_buf), sa_hndl->itype_size * merge_idx_len); ccl_comp_copy(buf_v.data(), (v_type*)((char*)(sa_hndl->dst_buf) + sa_hndl->itype_size * merge_idx_len), - sa_hndl->vtype_size * merge_idx_len * sa_hndl->val_dim_cnt, - ccl_datatype_int8); + sa_hndl->vtype_size * merge_idx_len * sa_hndl->val_dim_cnt); sa_hndl->dst_count[0] = merge_idx_len; sa_hndl->dst_count[1] = merge_idx_len * sa_hndl->val_dim_cnt; @@ -406,8 +402,7 @@ ccl::status sparse_reduce_ring(const void* ctx) { ccl_comp_copy(sa_hndl->recv_buf, sa_hndl->send_tmp_buf, - idx_size + sa_hndl->send_count[1] * sa_hndl->vtype_size, - ccl_datatype_int8); + idx_size + sa_hndl->send_count[1] * sa_hndl->vtype_size); sa_hndl->iter++; @@ -489,7 +484,7 @@ ccl::status sparse_set_max_buf_size_ring(const void* ctx) { size_t max_size = max_nnz * common_size_part; sa_hndl->send_tmp_buf = sa_hndl->sched->alloc_buffer(max_size).get_ptr(); - CCL_MEMCPY(sa_hndl->send_tmp_buf, sa_hndl->dst_buf, sa_hndl->dst_count[0] * common_size_part); + ccl::memcpy(sa_hndl->send_tmp_buf, sa_hndl->dst_buf, sa_hndl->dst_count[0] * common_size_part); sa_hndl->recv_buf = sa_hndl->sched->alloc_buffer(max_size).get_ptr(); return ccl::status::success; @@ -505,7 +500,7 @@ ccl::status sparse_coalesce_ring(const void* ctx) { sa_hndl->send_count[0] = iv_map_cnt; /* index count */ sa_hndl->send_count[1] = iv_map_cnt * sa_hndl->val_dim_cnt; /* value count */ - CCL_MEMCPY(&sa_hndl->dst_count, &sa_hndl->send_count, sizeof(size_t) * 2); + ccl::memcpy(&sa_hndl->dst_count, &sa_hndl->send_count, sizeof(size_t) * 2); CCL_SPARSE_ALLREDUCE_IF_SINGLE_RANK(); return ccl::status::success; @@ -559,37 +554,37 @@ ccl::status ccl_coll_build_sparse_allreduce_ring(ccl_sched* sched, sa_hndl->recv_counts = static_cast(sched->alloc_buffer(sizeof(size_t) * comm_size).get_ptr()); - entry_factory::make_entry(sched, sparse_coalesce_ring, sa_hndl); + entry_factory::create(sched, sparse_coalesce_ring, sa_hndl); sched->add_barrier(); if 
(comm_size > 1) { CCL_SPARSE_ALLREDUCE_ADD_NNZ_ENTRY(); - entry_factory::make_entry(sched, sparse_set_max_buf_size_ring, sa_hndl); + entry_factory::create(sched, sparse_set_max_buf_size_ring, sa_hndl); sched->add_barrier(); for (int i = 0; i < comm_size - 1; i++) { /* send local data to the right neighbour */ - send_entry* se = entry_factory::make_entry( + send_entry* se = entry_factory::create( sched, ccl_buffer(), 0, ccl_datatype_int8, send_to, comm); se->set_field_fn(sparse_get_send_buf_ring, sa_hndl); se->set_field_fn(sparse_get_send_count_ring, sa_hndl); /* receive data from the left neighbour */ - recv_entry* re = entry_factory::make_entry( + recv_entry* re = entry_factory::create( sched, ccl_buffer(), 0, ccl_datatype_int8, recv_from, comm); re->set_field_fn(sparse_get_recv_buf_ring, sa_hndl); re->set_field_fn(sparse_get_recv_count_ring, sa_hndl); sched->add_barrier(); /* reduce data */ - entry_factory::make_entry( + entry_factory::create( sched, sparse_reduce_ring, sa_hndl); sched->add_barrier(); } /* copy all reduced data to recv_buf */ - entry_factory::make_entry( + entry_factory::create( sched, sparse_prepare_result_ring, sa_hndl); sched->add_barrier(); } @@ -626,9 +621,9 @@ ccl::status sparse_create_matrix_mask(const void* ctx) { auto elem = sa_hndl->iv_map->find(*it); if (elem != sa_hndl->iv_map->end()) { /* copy values from dst_buf to matrix */ - CCL_MEMCPY(matrix + idx_offset * sa_hndl->val_dim_cnt, - values + elem->second[0], - value_line_size); + ccl::memcpy(matrix + idx_offset * sa_hndl->val_dim_cnt, + values + elem->second[0], + value_line_size); } else { /* no index was found locally, fill the line with mask */ @@ -647,10 +642,7 @@ ccl::status sparse_create_matrix_mask(const void* ctx) { sa_hndl->vtype_size * sa_hndl->dst_count[1]) .get_ptr(); - ccl_comp_copy(matrix, - (char*)sa_hndl->dst_buf + idx_cnt * sa_hndl->itype_size, - matrix_size, - ccl_datatype_int8); + ccl_comp_copy(matrix, (char*)sa_hndl->dst_buf + idx_cnt * sa_hndl->itype_size, 
matrix_size); CCL_FREE(matrix); sa_hndl->iv_map->clear(); @@ -765,13 +757,13 @@ ccl::status ccl_coll_build_sparse_allreduce_mask(ccl_sched* sched, sa_hndl->recv_counts = static_cast(sched->alloc_buffer(sizeof(size_t) * comm_size).get_ptr()); - entry_factory::make_entry(sched, sparse_coalesce_mask, sa_hndl); + entry_factory::create(sched, sparse_coalesce_mask, sa_hndl); sched->add_barrier(); if (comm_size > 1) { CCL_SPARSE_ALLREDUCE_ADD_NNZ_ENTRY(); - entry_factory::make_entry(sched, sparse_nnz_per_rank_mask, sa_hndl); + entry_factory::create(sched, sparse_nnz_per_rank_mask, sa_hndl); sched->add_barrier(); ccl_coll_entry_param param_allgatherv{}; @@ -784,13 +776,13 @@ ccl::status ccl_coll_build_sparse_allreduce_mask(ccl_sched* sched, param_allgatherv.comm = comm; /* gather indices from all the processes */ - coll_entry* e = entry_factory::make_entry(sched, param_allgatherv); + coll_entry* e = entry_factory::create(sched, param_allgatherv); e->set_field_fn(sparse_get_send_buf_mask, sa_hndl); e->set_field_fn(sparse_get_allgatherv_buf_mask, sa_hndl); e->set_field_fn(sparse_get_send_count_mask, sa_hndl); sched->add_barrier(); - entry_factory::make_entry( + entry_factory::create( sched, sparse_create_matrix_mask, sa_hndl); sched->add_barrier(); @@ -804,7 +796,7 @@ ccl::status ccl_coll_build_sparse_allreduce_mask(ccl_sched* sched, param_allreduce.comm = comm; /* coll allreduce on matrix data */ - coll_entry* ce = entry_factory::make_entry(sched, param_allreduce); + coll_entry* ce = entry_factory::create(sched, param_allreduce); ce->set_field_fn(sparse_get_allreduce_buf_mask, sa_hndl); ce->set_field_fn(sparse_get_allreduce_buf_mask, sa_hndl); ce->set_field_fn(sparse_get_allreduce_count_mask, sa_hndl); @@ -1085,7 +1077,7 @@ ccl::status ccl_coll_build_sparse_allreduce_3_allgatherv(ccl_sched* sched, sa_hndl->recv_counts); if (sched->coll_attr.sparse_coalesce_mode != ccl::sparse_coalesce_mode::disable) { - entry_factory::make_entry( + entry_factory::create( sched, 
sparse_coalesce_allgatherv, sa_hndl); sched->add_barrier(); @@ -1099,7 +1091,7 @@ ccl::status ccl_coll_build_sparse_allreduce_3_allgatherv(ccl_sched* sched, CCL_SPARSE_ALLREDUCE_ADD_NNZ_ENTRY(); - entry_factory::make_entry(sched, sparse_alloc_result_buf_allgatherv, sa_hndl); + entry_factory::create(sched, sparse_alloc_result_buf_allgatherv, sa_hndl); sched->add_barrier(); // allgather indices @@ -1113,12 +1105,12 @@ ccl::status ccl_coll_build_sparse_allreduce_3_allgatherv(ccl_sched* sched, param_i.dtype = index_dtype; param_i.comm = comm; - coll_entry* ce = entry_factory::make_entry(sched, param_i, parallel_request_index); + coll_entry* ce = entry_factory::create(sched, param_i, parallel_request_index); ce->set_field_fn(sparse_get_i_send_allgatherv, sa_hndl); ce->set_field_fn(sparse_get_i_recv_allgatherv, sa_hndl); ce->set_field_fn(sparse_get_send_count_allgatherv<0>, sa_hndl); - entry_factory::make_entry(sched, sparse_set_v_counts_allgatherv<1>, sa_hndl); + entry_factory::create(sched, sparse_set_v_counts_allgatherv<1>, sa_hndl); // allgather values parallel_request_index++; @@ -1131,7 +1123,7 @@ ccl::status ccl_coll_build_sparse_allreduce_3_allgatherv(ccl_sched* sched, param_v.dtype = value_dtype; param_v.comm = comm; - ce = entry_factory::make_entry(sched, param_v, parallel_request_index); + ce = entry_factory::create(sched, param_v, parallel_request_index); ce->set_field_fn(sparse_get_v_send_allgatherv, sa_hndl); ce->set_field_fn(sparse_get_v_recv_allgatherv, sa_hndl); ce->set_field_fn(sparse_get_send_count_allgatherv<1>, @@ -1139,11 +1131,10 @@ ccl::status ccl_coll_build_sparse_allreduce_3_allgatherv(ccl_sched* sched, sched->add_barrier(); if (sched->coll_attr.sparse_coalesce_mode == ccl::sparse_coalesce_mode::disable) { - entry_factory::make_entry( - sched, sparse_return_gathered_allgatherv, sa_hndl); + entry_factory::create(sched, sparse_return_gathered_allgatherv, sa_hndl); } else { - entry_factory::make_entry( + entry_factory::create( sched, 
sparse_reduce_gathered_allgatherv, sa_hndl); } sched->add_barrier(); diff --git a/src/coll/coll.cpp b/src/coll/coll.cpp index 0dca06c03..8af985d79 100644 --- a/src/coll/coll.cpp +++ b/src/coll/coll.cpp @@ -51,7 +51,7 @@ #include "common/global/global.hpp" #include "coll/algorithms/algorithms.hpp" -#include "coll/algorithms/algorithms_enum.hpp" +#include "coll/algorithms/algorithm_utils.hpp" #include "coll/algorithms/allreduce/allreduce_2d.hpp" #include "coll/algorithms/sparse_allreduce/sparse_allreduce.hpp" #include "coll/selection/selection.hpp" @@ -63,6 +63,13 @@ static ccl_request* ccl_coll_create(ccl_coll_param& param, const ccl_coll_attr& in_attr) { ccl_coll_attr& attr = const_cast(in_attr); +#if defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE) + uint64_t operation_create_time = 0; + if (ccl::global_data::env().enable_kernel_profile && param.stream) { + operation_create_time = ccl::ze::calculate_global_time(param.stream->get_ze_device()); + } +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE + #ifdef CCL_ENABLE_SYCL if (ccl::global_data::env().enable_op_sync) attr.synchronous = 1; @@ -85,8 +92,9 @@ static ccl_request* ccl_coll_create(ccl_coll_param& param, const ccl_coll_attr& bool postpone_schedule = false; if (ccl::global_data::env().enable_unordered_coll) { if (!attr.match_id.empty()) { - auto comm = - param.comm->unordered_coll_manager->get_comm(std::string(attr.match_id)).get(); + auto comm = param.comm->get_unordered_coll_manager() + ->get_comm(std::string(attr.match_id)) + .get(); if (!comm) { if (attr.synchronous) { CCL_THROW("unsupported collective (synchronous && unordered && !communicator)"); @@ -107,6 +115,12 @@ static ccl_request* ccl_coll_create(ccl_coll_param& param, const ccl_coll_attr& /* 2. 
create or get schedule */ ccl_master_sched* sched = ccl_master_sched::create(param, attr); +#if defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE) + if (ccl::global_data::env().enable_kernel_profile && param.stream) { + sched->get_kernel_timer().set_operation_create_time(operation_create_time); + } +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE + /* 3. fuse schedule */ if (!postpone_schedule && ccl::global_data::env().enable_fusion) { if (data.fusion_manager->add(sched)) { @@ -128,7 +142,7 @@ static ccl_request* ccl_coll_create(ccl_coll_param& param, const ccl_coll_attr& user has provided match_id that has not been resolved yet. schedule will be postponed until comm resolution */ - return param.comm->unordered_coll_manager->postpone(sched); + return param.comm->get_unordered_coll_manager()->postpone(sched); } /* 6. regular schedule execution */ @@ -138,6 +152,13 @@ static ccl_request* ccl_coll_create(ccl_coll_param& param, const ccl_coll_attr& request = nullptr; } +#if defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE) + if (ccl::global_data::env().enable_kernel_profile && sched->coll_param.stream) { + sched->get_kernel_timer().set_operation_start_time( + ccl::ze::calculate_global_time(sched->coll_param.stream->get_ze_device())); + } +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE + return request; } @@ -155,7 +176,12 @@ ccl::status ccl_coll_build_allgatherv(ccl_sched* sched, param.recv_counts = recv_counts; param.dtype = dtype; param.comm = comm; + param.stream = sched->coll_param.stream; + param.buf = send_buf.get_ptr(); param.is_vector_buf = sched->coll_attr.is_vector_buf; +#ifdef CCL_ENABLE_SYCL + param.is_sycl_buf = sched->coll_attr.is_sycl_buf; +#endif // CCL_ENABLE_SYCL param.hint_algo = sched->hint_algo; auto algo = ccl::global_data::get().algorithm_selector->get(param); @@ -173,6 +199,12 @@ ccl::status ccl_coll_build_allgatherv(ccl_sched* sched, CCL_CALL(ccl_coll_build_ring_allgatherv( sched, send_buf, send_count, recv_buf, recv_counts, dtype, comm)); break; +#if 
defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE) + case ccl_coll_allgatherv_topo: + CCL_CALL(ccl_coll_build_topo_allgatherv( + sched, send_buf, send_count, recv_buf, recv_counts, dtype, comm)); + break; +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE default: CCL_FATAL("unexpected allgatherv_algo ", ccl_coll_algorithm_to_str(algo)); return ccl::status::invalid_arguments; @@ -214,8 +246,8 @@ ccl::status ccl_coll_build_allreduce(ccl_sched* sched, CCL_CALL(ccl_coll_build_rabenseifner_allreduce( sched, send_buf, recv_buf, count, dtype, reduction, comm)); break; - case ccl_coll_allreduce_starlike: - CCL_CALL(ccl_coll_build_starlike_allreduce( + case ccl_coll_allreduce_nreduce: + CCL_CALL(ccl_coll_build_nreduce_allreduce( sched, send_buf, recv_buf, count, dtype, reduction, comm)); break; case ccl_coll_allreduce_ring: @@ -242,15 +274,15 @@ ccl::status ccl_coll_build_allreduce(ccl_sched* sched, sched, send_buf, recv_buf, count, dtype, reduction, comm)); break; case ccl_coll_allreduce_2d: - CCL_CALL(comm->allreduce_2d_builder->build( + CCL_CALL(comm->get_allreduce_2d_builder()->build( sched, send_buf, recv_buf, count, dtype, reduction)); break; -#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) - case ccl_coll_allreduce_topo_ring: - CCL_CALL(ccl_coll_build_gpu_allreduce( +#if defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE) + case ccl_coll_allreduce_topo: + CCL_CALL(ccl_coll_build_topo_allreduce( sched, send_buf, recv_buf, count, dtype, reduction, comm)); break; -#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE default: CCL_FATAL("unexpected allreduce_algo ", ccl_coll_algorithm_to_str(algo)); return ccl::status::invalid_arguments; @@ -272,6 +304,10 @@ ccl::status ccl_coll_build_alltoall(ccl_sched* sched, param.count = count; param.dtype = dtype; param.comm = comm; + param.stream = sched->coll_param.stream; +#ifdef CCL_ENABLE_SYCL + param.is_sycl_buf = sched->coll_attr.is_sycl_buf; +#endif // CCL_ENABLE_SYCL param.hint_algo 
= sched->hint_algo; auto algo = ccl::global_data::get().algorithm_selector->get(param); @@ -301,6 +337,10 @@ ccl::status ccl_coll_build_alltoallv(ccl_sched* sched, param.ctype = ccl_coll_alltoallv; param.dtype = dtype; param.comm = comm; + param.stream = sched->coll_param.stream; +#ifdef CCL_ENABLE_SYCL + param.is_sycl_buf = sched->coll_attr.is_sycl_buf; +#endif // CCL_ENABLE_SYCL param.hint_algo = sched->hint_algo; auto algo = ccl::global_data::get().algorithm_selector->get(param); @@ -357,6 +397,7 @@ ccl::status ccl_coll_build_bcast(ccl_sched* sched, param.dtype = dtype; param.comm = comm; param.stream = sched->coll_param.stream; + param.buf = buf.get_ptr(); #ifdef CCL_ENABLE_SYCL param.is_sycl_buf = sched->coll_attr.is_sycl_buf; #endif // CCL_ENABLE_SYCL @@ -387,11 +428,11 @@ ccl::status ccl_coll_build_bcast(ccl_sched* sched, case ccl_coll_bcast_naive: CCL_CALL(ccl_coll_build_naive_bcast(sched, buf, count, dtype, root, comm)); break; -#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) - case ccl_coll_bcast_topo_ring: +#if defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE) + case ccl_coll_bcast_topo: CCL_CALL(ccl_coll_build_gpu_bcast(sched, buf, count, dtype, root, comm)); break; -#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE default: CCL_FATAL("unexpected bcast_algo ", ccl_coll_algorithm_to_str(algo)); return ccl::status::invalid_arguments; @@ -408,6 +449,7 @@ ccl::status ccl_coll_build_reduce(ccl_sched* sched, int root, ccl_comm* comm) { ccl::status status = ccl::status::success; + CCL_THROW_IF_NOT(root >= 0 && root < comm->size(), "wrong root"); ccl_selector_param param; param.ctype = ccl_coll_reduce; @@ -415,6 +457,10 @@ ccl::status ccl_coll_build_reduce(ccl_sched* sched, param.dtype = dtype; param.comm = comm; param.stream = sched->coll_param.stream; + param.buf = send_buf.get_ptr(); +#ifdef CCL_ENABLE_SYCL + param.is_sycl_buf = sched->coll_attr.is_sycl_buf; +#endif // CCL_ENABLE_SYCL param.hint_algo = 
sched->hint_algo; auto algo = ccl::global_data::get().algorithm_selector->get(param); @@ -444,12 +490,12 @@ ccl::status ccl_coll_build_reduce(ccl_sched* sched, root == 0 ? comm->dtree() : comm->dtree().copy_with_new_root(root), comm)); break; -#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) - case ccl_coll_reduce_topo_ring: +#if defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE) + case ccl_coll_reduce_topo: CCL_CALL(ccl_coll_build_gpu_reduce( sched, send_buf, recv_buf, count, dtype, reduction, root, comm)); break; -#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE default: CCL_FATAL("unexpected reduce_algo ", ccl_coll_algorithm_to_str(algo)); return ccl::status::invalid_arguments; @@ -473,6 +519,11 @@ ccl::status ccl_coll_build_reduce_scatter(ccl_sched* sched, param.count = count; param.dtype = dtype; param.comm = comm; + param.stream = sched->coll_param.stream; + param.buf = send_buf.get_ptr(); +#ifdef CCL_ENABLE_SYCL + param.is_sycl_buf = sched->coll_attr.is_sycl_buf; +#endif // CCL_ENABLE_SYCL param.hint_algo = sched->hint_algo; auto algo = ccl::global_data::get().algorithm_selector->get(param); @@ -494,6 +545,12 @@ ccl::status ccl_coll_build_reduce_scatter(ccl_sched* sched, sched, send_buf, recv_buf, count, dtype, reduction, comm)); } break; +#if defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE) + case ccl_coll_reduce_scatter_topo: + CCL_CALL(ccl_coll_build_topo_reduce_scatter( + sched, send_buf, recv_buf, count, dtype, reduction, comm)); + break; +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE default: CCL_FATAL("unexpected reduce_scatter_algo ", ccl_coll_algorithm_to_str(algo)); return ccl::status::invalid_arguments; diff --git a/src/coll/coll.hpp b/src/coll/coll.hpp index b455eefe7..e38d9941a 100644 --- a/src/coll/coll.hpp +++ b/src/coll/coll.hpp @@ -15,11 +15,11 @@ */ #pragma once -#include "coll/algorithms/algorithms_enum.hpp" -#include "common/comm/comm.hpp" +#include "coll/algorithms/algorithm_utils.hpp" 
#include "coll/coll_param.hpp" -#include "common/stream/stream.hpp" +#include "common/comm/comm.hpp" #include "common/datatype/datatype.hpp" +#include "common/stream/stream.hpp" #include "common/utils/buffer.hpp" #include "coll/coll_common_attributes.hpp" diff --git a/src/coll/coll_param.cpp b/src/coll/coll_param.cpp index d3cb9b54b..6483b0625 100644 --- a/src/coll/coll_param.cpp +++ b/src/coll/coll_param.cpp @@ -17,6 +17,7 @@ #include "coll/coll_param.hpp" #include "common/global/global.hpp" +#include "common/utils/sycl_utils.hpp" #define COPY_COMMON_OP_ATTRS(from, to) \ to->prologue_fn = nullptr; /*from.get().get();*/ \ @@ -74,20 +75,6 @@ ccl_coll_attr::ccl_coll_attr(const ccl::sparse_allreduce_attr& attr) { sparse_coalesce_mode = attr.get(); } -bool operator==(const coll_param_gpu& lhs, const coll_param_gpu& rhs) { - CCL_ASSERT((lhs.is_reduction() && rhs.is_reduction()) || - (!lhs.is_reduction() && !rhs.is_reduction())); - - bool res = - lhs.get_coll_type() == rhs.get_coll_type() && lhs.get_datatype() == rhs.get_datatype(); - - if (lhs.is_reduction()) { - res = res && (lhs.get_reduction() == rhs.get_reduction()); - } - - return res; -} - std::string ccl_coll_attr::to_string() const { std::stringstream ss; @@ -144,14 +131,19 @@ std::string ccl_coll_param::to_string() const { ss << "{ "; ss << "coll: " << ccl_coll_type_to_str(ctype); - if (!send_bufs.empty()) - ss << ", sb: " << get_send_buf() << ", sc: " << get_send_count(); + if (!send_bufs.empty()) { + ss << ", sb: " << get_send_buf() + << ", sc: " << std::accumulate(send_counts.begin(), send_counts.end(), 0); + } - if (!recv_bufs.empty()) - ss << ", rb: " << get_recv_buf() << ", rc: " << get_recv_count(); + if (!recv_bufs.empty()) { + ss << ", rb: " << get_recv_buf() + << ", rc: " << std::accumulate(recv_counts.begin(), recv_counts.end(), 0); + } - if (ctype != ccl_coll_barrier) + if (ctype != ccl_coll_barrier) { ss << ", dt: " << ccl::global_data::get().dtypes->name(dtype); + } if (ctype == ccl_coll_allreduce 
|| ctype == ccl_coll_reduce || ctype == ccl_coll_reduce_scatter) { @@ -223,7 +215,15 @@ bool ccl_coll_param::is_inplace(buf_type type) const { } void* send_buf = get_send_buf(0, type); - void* recv_buf = get_recv_buf(0, type); + void* recv_buf = nullptr; + + if ((ctype == ccl_coll_allgatherv) && (recv_bufs.size() > 1)) { + recv_buf = get_recv_buf(comm->rank(), type); + } + else { + recv_buf = get_recv_buf(0, type); + } + return (send_buf && (send_buf == recv_buf)) ? true : false; } @@ -468,7 +468,7 @@ void ccl_coll_param::sync_deps(const ccl_stream* s, const std::vectorget_native_stream().submit_barrier(); + auto sycl_ev = ccl::utils::submit_barrier(s->get_native_stream()); auto e = ccl::create_event(sycl_ev); copy_deps(ds, &e); return; diff --git a/src/coll/coll_param.hpp b/src/coll/coll_param.hpp index 049f5c56a..31e8b4731 100644 --- a/src/coll/coll_param.hpp +++ b/src/coll/coll_param.hpp @@ -17,7 +17,7 @@ #include -#include "coll/algorithms/algorithms_enum.hpp" +#include "coll/algorithms/algorithm_utils.hpp" #include "common/datatype/datatype.hpp" #include "oneapi/ccl.hpp" @@ -228,44 +228,3 @@ struct ccl_coll_param { const ccl_stream* stream, const std::vector& deps = {}); }; - -class coll_param_gpu { - ccl_coll_type ctype; - ccl::datatype dtype; - ccl::reduction red; - -public: - coll_param_gpu(ccl_coll_type ctype, ccl::datatype dtype, ccl::reduction red) - : ctype{ ctype }, - dtype{ dtype }, - red{ red } {} - - coll_param_gpu(ccl_coll_type ctype, ccl::datatype dtype) - : ctype{ ctype }, - dtype{ dtype }, - red{ (ccl::reduction)-1 } { - assert(!is_reduction() && "This constructor is invalid for reduction types"); - } - - ccl_coll_type get_coll_type() const { - return ctype; - } - - ccl::datatype get_datatype() const { - return dtype; - } - - bool is_reduction() const { - return ccl_coll_type_is_reduction(get_coll_type()); - } - - ccl::reduction get_reduction() const { - if (!is_reduction()) { - throw ccl::exception( - "get_ruduction(): is not supported for 
non-reduction collective type, i.e. bcast"); - } - return red; - } -}; - -bool operator==(const coll_param_gpu& lhs, const coll_param_gpu& rhs); diff --git a/src/coll/coll_util.cpp b/src/coll/coll_util.cpp new file mode 100644 index 000000000..679260dfa --- /dev/null +++ b/src/coll/coll_util.cpp @@ -0,0 +1,104 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#include "coll_util.hpp" + +#include "sched/entry/coll/coll_entry_helper.hpp" +#include "sched/entry/factory/entry_factory.hpp" +#include "sched/entry/ze/ze_event_signal_entry.hpp" +#include "sched/entry/ze/ze_event_wait_entry.hpp" + +namespace ccl { + +void add_wait_events(ccl_sched* sched, const std::vector& wait_events) { + if (wait_events.size() > 0) { + entry_factory::create(sched, wait_events); + sched->add_barrier(); + } +} + +void add_signal_event(ccl_sched* sched, ze_event_handle_t signal_event) { + if (signal_event) { + entry_factory::create(sched, signal_event); + sched->add_barrier(); + } +} + +ze_event_handle_t add_signal_event(ccl_sched* sched) { + auto signal_event = sched->get_memory().event_manager->create(); + add_signal_event(sched, signal_event); + return signal_event; +} + +void add_comm_barrier(ccl_sched* sched, + ccl_comm* comm, + ze_event_pool_handle_t ipc_pool, + size_t ipc_event_idx) { + if (ipc_pool && global_data::env().enable_ze_barrier) { + entry_factory::create(sched, comm, ipc_pool, ipc_event_idx); + } + else { + ccl_coll_entry_param 
barrier_param{}; + barrier_param.ctype = ccl_coll_barrier; + barrier_param.comm = comm; + + /* TODO: optimize p2p based barrier */ + //barrier_param.hint_algo.barrier = ccl_coll_barrier_ring; + + coll_entry_helper::add_coll_entry(sched, barrier_param); + } + sched->add_barrier(); +} + +ze_event_handle_t add_comm_barrier(ccl_sched* sched, + ccl_comm* comm, + const std::vector& wait_events, + ze_event_pool_handle_t ipc_pool, + size_t ipc_event_idx) { + auto signal_event = sched->get_memory().event_manager->create(); + if (sched->get_memory().use_single_list) { + add_wait_events(sched, wait_events); + add_comm_barrier(sched, comm, ipc_pool, ipc_event_idx); + add_signal_event(sched, signal_event); + } + else { + add_comm_barrier(sched, comm, ipc_pool, ipc_event_idx); + add_signal_event(sched, signal_event); + } + return signal_event; +} + +void add_handle_exchange(ccl_sched* sched, + ccl_comm* comm, + const std::vector& in_buffers, + int skip_rank, + ze_event_pool_handle_t pool, + size_t event_idx) { + if (sched->coll_attr.to_cache) { + sched->set_entry_exec_mode(ccl_sched_entry_exec_once); + entry_factory::create(sched, comm, in_buffers, skip_rank); + sched->add_barrier(); + sched->set_entry_exec_mode(ccl_sched_entry_exec_regular); + + // TODO: no need barrier for the first iteration where ze_handle_exchange_entry exists + add_comm_barrier(sched, comm, pool, event_idx); + } + else { + entry_factory::create(sched, comm, in_buffers, skip_rank); + sched->add_barrier(); + } +} + +} // namespace ccl diff --git a/src/coll/coll_util.hpp b/src/coll/coll_util.hpp new file mode 100644 index 000000000..52f52bab6 --- /dev/null +++ b/src/coll/coll_util.hpp @@ -0,0 +1,45 @@ +/* + Copyright 2016-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +#include "common/global/global.hpp" +#include "sched/entry/ze/ze_handle_exchange_entry.hpp" + +namespace ccl { + +void add_wait_events(ccl_sched* sched, const std::vector& wait_events); +void add_signal_event(ccl_sched* sched, ze_event_handle_t signal_event); +ze_event_handle_t add_signal_event(ccl_sched* sched); + +void add_comm_barrier(ccl_sched* sched, + ccl_comm* comm, + ze_event_pool_handle_t ipc_pool = {}, + size_t ipc_event_idx = 0); + +ze_event_handle_t add_comm_barrier(ccl_sched* sched, + ccl_comm* comm, + const std::vector& wait_events, + ze_event_pool_handle_t ipc_pool = {}, + size_t ipc_event_idx = 0); + +void add_handle_exchange(ccl_sched* sched, + ccl_comm* comm, + const std::vector& in_buffers, + int skip_rank = ccl_comm::invalid_rank, + ze_event_pool_handle_t pool = nullptr, + size_t event_idx = 0); + +} // namespace ccl diff --git a/src/coll/selection/selection.cpp b/src/coll/selection/selection.cpp index a90c83600..48ca3d811 100644 --- a/src/coll/selection/selection.cpp +++ b/src/coll/selection/selection.cpp @@ -14,8 +14,53 @@ limitations under the License. 
*/ #include "coll/selection/selection.hpp" +#include "common/comm/comm.hpp" #include "common/global/global.hpp" +#if defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE) +#include +#include "common/utils/sycl_utils.hpp" +#include "sched/entry/ze/ze_primitives.hpp" +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE + +std::string to_string(const ccl_selector_param& param) { + std::stringstream ss; + + ss << "{ " + << "coll: " << ccl_coll_type_to_str(param.ctype) << ", count: " << param.count + << ", dt: " << ccl::global_data::get().dtypes->name(param.dtype); + + if (param.comm) { + ss << ", comm: { rank: " << param.comm->rank() << ", size: " << param.comm->size() << " }"; + } + + if (param.stream) { + ss << ", stream: " << param.stream->to_string(); + } + + if (param.buf) { + ss << ", buf: " << param.buf; + } + + if (param.is_vector_buf) { + ss << ", vector_buf"; + } + +#ifdef CCL_ENABLE_SYCL + if (param.is_sycl_buf) { + ss << ", sycl_buf"; + } +#endif // CCL_ENABLE_SYCL + + if (param.hint_algo.has_value()) { + ss << ", hint_algo: " << param.hint_algo.value; + } + + ss << " }"; + + return ss.str(); +} + bool ccl_is_direct_algo(const ccl_selector_param& param) { bool res = false; @@ -49,22 +94,195 @@ bool ccl_is_direct_algo(const ccl_selector_param& param) { return res; } +namespace checkers { + +bool is_family1_card(const ccl_selector_param& param) { + if (param.stream) { + return param.stream->get_device_family() == ccl::device_family::family1; + } + return false; +} + +bool is_coll_supported(std::initializer_list colls, ccl_coll_type value) { + return std::find(colls.begin(), colls.end(), value) != colls.end(); +} + +bool is_sycl_buf(const ccl_selector_param& param) { +#ifdef CCL_ENABLE_SYCL + return param.is_sycl_buf; +#endif // CCL_ENABLE_SYCL + return false; +} + +bool is_device_buf(const ccl_selector_param& param) { +#ifdef CCL_ENABLE_SYCL + if (param.buf && param.stream) { + auto ctx = param.stream->get_native_stream().get_context(); + return 
sycl::get_pointer_type(param.buf, ctx) == sycl::usm::alloc::device; + } +#endif // CCL_ENABLE_SYCL + return true; +} + +bool is_l0_backend(const ccl_selector_param& param) { +#if defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE) + if (param.stream) { + return param.stream->get_backend() == ccl::utils::get_level_zero_backend(); + } +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE + return false; +} + +bool is_gpu_stream(const ccl_selector_param& param) { + if (param.stream) { + return param.stream->is_gpu(); + } + return false; +} + +bool is_single_node(const ccl_selector_param& param) { + size_t local_proc_count = ccl::global_data::get().executor->get_local_proc_count(); + return static_cast(param.comm->size()) == local_proc_count; +} + +bool is_single_card(const ccl_selector_param& param) { + return (param.comm->size() == 2) && is_single_node(param); +} + +} // namespace checkers + +#define RETURN_FALSE_IF(cond, ...) \ + do { \ + if (cond) { \ + LOG_DEBUG("selection checker: ", ##__VA_ARGS__); \ + return false; \ + } \ + } while (0) + static bool ccl_is_device_side_algo(ccl_coll_algo algo, const ccl_selector_param& param) { - if (param.ctype == ccl_coll_allreduce) { - return algo.allreduce == ccl_coll_allreduce_topo_ring; + CCL_THROW_IF_NOT(algo.has_value(), "empty algo value"); + + if (param.ctype == ccl_coll_allgatherv) { + return algo.allgatherv == ccl_coll_allgatherv_topo; } - else if (param.ctype == ccl_coll_reduce) { - return algo.reduce == ccl_coll_reduce_topo_ring; + else if (param.ctype == ccl_coll_allreduce) { + return algo.allreduce == ccl_coll_allreduce_topo; } else if (param.ctype == ccl_coll_bcast) { - return algo.bcast == ccl_coll_bcast_topo_ring; + return algo.bcast == ccl_coll_bcast_topo; + } + else if (param.ctype == ccl_coll_reduce) { + return algo.reduce == ccl_coll_reduce_topo; + } + else if (param.ctype == ccl_coll_reduce_scatter) { + return algo.reduce_scatter == ccl_coll_reduce_scatter_topo; } return false; } +bool ccl_is_device_side_algo(const 
ccl_selector_param& param) { +#ifndef CCL_ENABLE_SYCL + return false; +#endif // CCL_ENABLE_SYCL + + auto supported_colls = { ccl_coll_allgatherv, + ccl_coll_allreduce, + ccl_coll_bcast, + ccl_coll_reduce, + ccl_coll_reduce_scatter }; + RETURN_FALSE_IF(!checkers::is_coll_supported(supported_colls, param.ctype), + "coll ", + ccl_coll_type_to_str(param.ctype), + " is not supported"); + + ccl_coll_algo algo{}; + auto& selector = ccl::global_data::get().algorithm_selector; + + if (param.ctype == ccl_coll_allgatherv) { + algo.allgatherv = selector->get(param); + } + else if (param.ctype == ccl_coll_allreduce) { + algo.allreduce = selector->get(param); + } + else if (param.ctype == ccl_coll_bcast) { + algo.bcast = selector->get(param); + } + else if (param.ctype == ccl_coll_reduce) { + algo.reduce = selector->get(param); + } + else if (param.ctype == ccl_coll_reduce_scatter) { + algo.reduce_scatter = selector->get(param); + } + + return ccl_is_device_side_algo(algo, param); +} + +bool ccl_can_use_topo_algo(const ccl_selector_param& param) { + auto supported_colls = { ccl_coll_allgatherv, + ccl_coll_allreduce, + ccl_coll_bcast, + ccl_coll_reduce, + ccl_coll_reduce_scatter }; + RETURN_FALSE_IF(!checkers::is_coll_supported(supported_colls, param.ctype), + "coll is not supported"); + + size_t local_proc_count = ccl::global_data::get().executor->get_local_proc_count(); + int comm_size = param.comm->size(); + + RETURN_FALSE_IF(!checkers::is_gpu_stream(param), "non-gpu stream is not supported"); + RETURN_FALSE_IF(checkers::is_sycl_buf(param), "sycl buffer is not supported"); + RETURN_FALSE_IF(!checkers::is_device_buf(param), "non-device buffers is not supported"); + RETURN_FALSE_IF(!checkers::is_l0_backend(param), "non-l0 backend is not supported"); + + RETURN_FALSE_IF(ccl::global_data::env().enable_fusion, "fusion is not supported"); + RETURN_FALSE_IF(ccl::global_data::env().enable_unordered_coll, + "unordered coll is not supported"); + 
RETURN_FALSE_IF(ccl::global_data::env().priority_mode != ccl_priority_none, "wrong priority"); + RETURN_FALSE_IF(ccl::global_data::env().worker_count != 1, "unsupported count of workers"); + +#ifdef CCL_ENABLE_SYCL + if (!ccl::global_data::env().disable_ze_family_check) { + RETURN_FALSE_IF( + checkers::is_family1_card(param) && + (((!checkers::is_single_card(param) && + ((param.ctype == ccl_coll_allreduce || param.ctype == ccl_coll_reduce || + param.ctype == ccl_coll_allgatherv)))) || + (param.ctype == ccl_coll_reduce_scatter)), + "family1 multi-card for ", + ccl_coll_type_to_str(param.ctype), + " is not supported"); + } +#endif // CCL_ENABLE_SYCL + + RETURN_FALSE_IF((((param.ctype == ccl_coll_bcast) || (param.ctype == ccl_coll_reduce)) && + ((comm_size < 2) || (local_proc_count == 1))) || + ((param.ctype == ccl_coll_allreduce || param.ctype == ccl_coll_reduce) && + (comm_size <= 2) && (local_proc_count == 1)), + "unsupported comm size for ", + ccl_coll_type_to_str(param.ctype)); + + RETURN_FALSE_IF((param.ctype == ccl_coll_bcast || param.ctype == ccl_coll_reduce_scatter) && + !checkers::is_single_node(param), + "multi-node for ", + ccl_coll_type_to_str(param.ctype), + " is not supported"); + + RETURN_FALSE_IF(((param.ctype == ccl_coll_reduce) && (comm_size % local_proc_count != 0)), + "ppn must be equal"); + + RETURN_FALSE_IF(param.ctype == ccl_coll_allgatherv && !checkers::is_single_card(param) && + comm_size % local_proc_count != 0, + "ppn must be equal"); + + RETURN_FALSE_IF(!checkers::is_single_card(param) && !checkers::is_single_node(param) && + (local_proc_count % 2 != 0), + "odd proc count per node is not supported"); + return true; +} + bool ccl_can_use_datatype(ccl_coll_algo algo, const ccl_selector_param& param) { - // A regular type, so we don't need to check for an additional support + // regular datatype, don't need to check for an additional support if (param.dtype.idx() != ccl::datatype::bfloat16 && param.dtype.idx() != ccl::datatype::float16) { 
return true; @@ -74,10 +292,10 @@ bool ccl_can_use_datatype(ccl_coll_algo algo, const ccl_selector_param& param) { bool device_side_algo = ccl_is_device_side_algo(algo, param); - // Algorithms running on GPU device support both fp16 and bf16, so we don't need to require their - // support on the host. + // algorithms running on device side support fp16 and bf16 both + // so we don't need to require their support on the host if (!device_side_algo) { - if (param.dtype.idx() == ccl::datatype::bfloat16) { + if (param.dtype == ccl::datatype::bfloat16) { bool bf16_hw_support = ccl::global_data::env().bf16_impl_type != ccl_bf16_no_hardware_support; bool bf16_compiler_support = @@ -94,7 +312,7 @@ bool ccl_can_use_datatype(ccl_coll_algo algo, const ccl_selector_param& param) { bf16_compiler_support); } } - else if (param.dtype.idx() == ccl::datatype::float16) { + else if (param.dtype == ccl::datatype::float16) { bool fp16_hw_support = ccl::global_data::env().fp16_impl_type != ccl_fp16_no_hardware_support; bool fp16_compiler_support = @@ -115,71 +333,3 @@ bool ccl_can_use_datatype(ccl_coll_algo algo, const ccl_selector_param& param) { return can_use; } - -bool ccl_is_topo_ring_algo(const ccl_selector_param& param) { -#ifndef CCL_ENABLE_SYCL - return false; -#endif // CCL_ENABLE_SYCL - - if ((param.ctype != ccl_coll_allreduce) && (param.ctype != ccl_coll_bcast) && - (param.ctype != ccl_coll_reduce)) { - return false; - } - - bool res = false; - - auto& selector = ccl::global_data::get().algorithm_selector; - - if (param.ctype == ccl_coll_allreduce) { - res = (selector->get(param) == ccl_coll_allreduce_topo_ring); - } - else if (param.ctype == ccl_coll_bcast) { - res = (selector->get(param) == ccl_coll_bcast_topo_ring); - } - else if (param.ctype == ccl_coll_reduce) { - res = (selector->get(param) == ccl_coll_reduce_topo_ring); - } - - return res; -} - -bool ccl_can_use_topo_ring_algo(const ccl_selector_param& param) { - if ((param.ctype != ccl_coll_allreduce) && (param.ctype 
!= ccl_coll_bcast) && - (param.ctype != ccl_coll_reduce)) { - return false; - } - - bool is_sycl_buf = false; - bool is_device_buf = true; - bool is_l0_backend = false; - - size_t local_proc_count = ccl::global_data::get().executor->get_local_proc_count(); - -#ifdef CCL_ENABLE_SYCL - is_sycl_buf = param.is_sycl_buf; - if (param.buf && param.stream) { - auto ctx = param.stream->get_native_stream().get_context(); - is_device_buf = - (sycl::get_pointer_type(param.buf, ctx) == sycl::usm::alloc::device) ? true : false; - } -#ifdef MULTI_GPU_SUPPORT - if (param.stream && param.stream->get_backend() == sycl::backend::level_zero) { - is_l0_backend = true; - } -#endif // MULTI_GPU_SUPPORT -#endif // CCL_ENABLE_SYCL - - if ((param.comm->size() != 2 && param.comm->size() != 4) || - (param.comm->size() == 2 && param.comm->size() != static_cast(local_proc_count)) || - (param.comm->size() == 4 && local_proc_count != 2 && local_proc_count != 4) || - (param.comm->size() != 2 && (ccl::global_data::env().atl_transport == ccl_atl_mpi)) || - !param.stream || (param.stream->get_type() != stream_type::gpu) || is_sycl_buf || - !is_device_buf || !is_l0_backend || ccl::global_data::env().enable_fusion || - ccl::global_data::env().enable_unordered_coll || - (ccl::global_data::env().priority_mode != ccl_priority_none) || - (ccl::global_data::env().worker_count != 1)) { - return false; - } - - return true; -} diff --git a/src/coll/selection/selection.hpp b/src/coll/selection/selection.hpp index 2a8fe2a28..9f3fc34dc 100644 --- a/src/coll/selection/selection.hpp +++ b/src/coll/selection/selection.hpp @@ -17,8 +17,9 @@ #include "coll/selection/selector_wrapper.hpp" -bool ccl_can_use_datatype(ccl_coll_algo algo, const ccl_selector_param& param); - bool ccl_is_direct_algo(const ccl_selector_param& param); -bool ccl_is_topo_ring_algo(const ccl_selector_param& param); -bool ccl_can_use_topo_ring_algo(const ccl_selector_param& param); +bool ccl_is_device_side_algo(const ccl_selector_param& param); + 
+bool ccl_can_use_topo_algo(const ccl_selector_param& param); + +bool ccl_can_use_datatype(ccl_coll_algo algo, const ccl_selector_param& param); diff --git a/src/coll/selection/selector_allgatherv.cpp b/src/coll/selection/selector_allgatherv.cpp index 28e5ebaa6..92311364d 100644 --- a/src/coll/selection/selector_allgatherv.cpp +++ b/src/coll/selection/selector_allgatherv.cpp @@ -15,6 +15,8 @@ */ #include "coll/selection/selection.hpp" +#include + template <> std::map ccl_algorithm_selector_helper::algo_names = { @@ -22,7 +24,8 @@ std::map std::make_pair(ccl_coll_allgatherv_naive, "naive"), std::make_pair(ccl_coll_allgatherv_ring, "ring"), std::make_pair(ccl_coll_allgatherv_flat, "flat"), - std::make_pair(ccl_coll_allgatherv_multi_bcast, "multi_bcast") + std::make_pair(ccl_coll_allgatherv_multi_bcast, "multi_bcast"), + std::make_pair(ccl_coll_allgatherv_topo, "topo") }; ccl_algorithm_selector::ccl_algorithm_selector() { @@ -33,8 +36,9 @@ ccl_algorithm_selector::ccl_algorithm_selector() { CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_allgatherv_ring); } - else if (ccl::global_data::env().atl_transport == ccl_atl_mpi) + else if (ccl::global_data::env().atl_transport == ccl_atl_mpi) { insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_allgatherv_direct); + } insert(fallback_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_allgatherv_flat); } @@ -46,15 +50,21 @@ bool ccl_algorithm_selector_helper::can_use( const ccl_selection_table_t& table) { bool can_use = true; - if (param.is_vector_buf && algo != ccl_coll_allgatherv_flat && - algo != ccl_coll_allgatherv_multi_bcast) + if (algo == ccl_coll_allgatherv_topo && !ccl_can_use_topo_algo(param)) { can_use = false; - else if (ccl::global_data::env().atl_transport == ccl_atl_mpi && - algo == ccl_coll_allgatherv_multi_bcast) + } + else if (param.is_vector_buf && algo != ccl_coll_allgatherv_flat && + algo != ccl_coll_allgatherv_multi_bcast) { can_use = false; + } + else if (algo == ccl_coll_allgatherv_multi_bcast && + 
ccl::global_data::env().atl_transport == ccl_atl_mpi) { + can_use = false; + } else if (algo == ccl_coll_allgatherv_direct && - (ccl::global_data::env().atl_transport == ccl_atl_ofi)) + ccl::global_data::env().atl_transport == ccl_atl_ofi) { can_use = false; + } return can_use; } @@ -63,11 +73,11 @@ CCL_SELECTION_DEFINE_HELPER_METHODS(ccl_coll_allgatherv_algo, ccl_coll_allgatherv, ccl::global_data::env().allgatherv_algo_raw, ({ - CCL_ASSERT(param.recv_counts); - size_t count = 0; - for (int idx = 0; idx < param.comm->size(); idx++) { - count += param.recv_counts[idx]; - } + CCL_THROW_IF_NOT(param.recv_counts); + size_t count = + std::accumulate(param.recv_counts, + param.recv_counts + param.comm->size(), + 0); count /= param.comm->size(); count; })); diff --git a/src/coll/selection/selector_allreduce.cpp b/src/coll/selection/selector_allreduce.cpp index 2ca2aa831..1c7dc96d7 100644 --- a/src/coll/selection/selector_allreduce.cpp +++ b/src/coll/selection/selector_allreduce.cpp @@ -20,33 +20,42 @@ std::map ccl_algorithm_selector_helper::algo_names = { std::make_pair(ccl_coll_allreduce_direct, "direct"), std::make_pair(ccl_coll_allreduce_rabenseifner, "rabenseifner"), - std::make_pair(ccl_coll_allreduce_starlike, "starlike"), + std::make_pair(ccl_coll_allreduce_nreduce, "nreduce"), std::make_pair(ccl_coll_allreduce_ring, "ring"), std::make_pair(ccl_coll_allreduce_ring_rma, "ring_rma"), std::make_pair(ccl_coll_allreduce_double_tree, "double_tree"), std::make_pair(ccl_coll_allreduce_recursive_doubling, "recursive_doubling"), std::make_pair(ccl_coll_allreduce_2d, "2d"), - std::make_pair(ccl_coll_allreduce_topo_ring, "topo_ring") + std::make_pair(ccl_coll_allreduce_topo, "topo"), }; ccl_algorithm_selector::ccl_algorithm_selector() { -#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) - insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_allreduce_topo_ring); -#else // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT +#if defined(CCL_ENABLE_SYCL) && 
defined(CCL_ENABLE_ZE) + insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_allreduce_topo); + if (ccl::global_data::env().atl_transport == ccl_atl_ofi) { + insert(fallback_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_allreduce_ring); + insert( + fallback_table, 0, CCL_ALLREDUCE_SHORT_MSG_SIZE, ccl_coll_allreduce_recursive_doubling); + } + else { + insert(fallback_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_allreduce_direct); + } +#else // CCL_ENABLE_SYCL && CCL_ENABLE_ZE if (ccl::global_data::env().atl_transport == ccl_atl_ofi) { insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_allreduce_ring); insert(main_table, 0, CCL_ALLREDUCE_SHORT_MSG_SIZE, ccl_coll_allreduce_recursive_doubling); insert(main_table, CCL_ALLREDUCE_SHORT_MSG_SIZE + 1, CCL_ALLREDUCE_MEDIUM_MSG_SIZE, - ccl_coll_allreduce_starlike); + ccl_coll_allreduce_nreduce); } - else if (ccl::global_data::env().atl_transport == ccl_atl_mpi) + else if (ccl::global_data::env().atl_transport == ccl_atl_mpi) { insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_allreduce_direct); -#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT + } insert(fallback_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_allreduce_ring); insert(fallback_table, 0, CCL_ALLREDUCE_SHORT_MSG_SIZE, ccl_coll_allreduce_recursive_doubling); +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE } template <> @@ -63,9 +72,9 @@ bool ccl_algorithm_selector_helper::can_use( if (algo == ccl_coll_allreduce_rabenseifner && static_cast(param.count) < param.comm->pof2()) can_use = false; - else if (algo == ccl_coll_allreduce_ring_rma && !atl_wrapper::attr.out.enable_rma) + else if (algo == ccl_coll_allreduce_ring_rma && !atl_base_comm::attr.out.enable_rma) can_use = false; - else if (algo == ccl_coll_allreduce_starlike && !(param.count / param.comm->size())) + else if (algo == ccl_coll_allreduce_nreduce && !(param.count / param.comm->size())) can_use = false; else if (algo == ccl_coll_allreduce_2d && 
(ccl::global_data::env().atl_transport == ccl_atl_mpi)) @@ -73,7 +82,7 @@ bool ccl_algorithm_selector_helper::can_use( else if (algo == ccl_coll_allreduce_direct && (ccl::global_data::env().atl_transport == ccl_atl_ofi)) can_use = false; - else if (algo == ccl_coll_allreduce_topo_ring && !ccl_can_use_topo_ring_algo(param)) + else if (algo == ccl_coll_allreduce_topo && !ccl_can_use_topo_algo(param)) can_use = false; return can_use; diff --git a/src/coll/selection/selector_bcast.cpp b/src/coll/selection/selector_bcast.cpp index 786bc22e1..604e1c54b 100644 --- a/src/coll/selection/selector_bcast.cpp +++ b/src/coll/selection/selector_bcast.cpp @@ -22,13 +22,13 @@ std::map std::make_pair(ccl_coll_bcast_ring, "ring"), std::make_pair(ccl_coll_bcast_double_tree, "double_tree"), std::make_pair(ccl_coll_bcast_naive, "naive"), - std::make_pair(ccl_coll_bcast_topo_ring, "topo_ring") + std::make_pair(ccl_coll_bcast_topo, "topo") }; ccl_algorithm_selector::ccl_algorithm_selector() { -#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) - insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_bcast_topo_ring); -#else // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT +#if defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE) + insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_bcast_topo); +#else // CCL_ENABLE_SYCL && CCL_ENABLE_ZE if (ccl::global_data::env().atl_transport == ccl_atl_ofi) { insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_bcast_naive); insert(main_table, 0, CCL_BCAST_SHORT_MSG_SIZE, ccl_coll_bcast_double_tree); @@ -36,7 +36,7 @@ ccl_algorithm_selector::ccl_algorithm_selector() { else if (ccl::global_data::env().atl_transport == ccl_atl_mpi) { insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_bcast_direct); } -#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE insert(fallback_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_bcast_naive); } @@ -57,10 +57,12 @@ bool 
ccl_algorithm_selector_helper::can_use( can_use = false; } else if (algo == ccl_coll_bcast_direct && - (ccl::global_data::env().atl_transport == ccl_atl_ofi)) + (ccl::global_data::env().atl_transport == ccl_atl_ofi)) { can_use = false; - else if (algo == ccl_coll_bcast_topo_ring && !ccl_can_use_topo_ring_algo(param)) + } + else if (algo == ccl_coll_bcast_topo && !ccl_can_use_topo_algo(param)) { can_use = false; + } return can_use; } diff --git a/src/coll/selection/selector_impl.hpp b/src/coll/selection/selector_impl.hpp index 92bbe443f..6952d4ac4 100644 --- a/src/coll/selection/selector_impl.hpp +++ b/src/coll/selection/selector_impl.hpp @@ -29,6 +29,8 @@ #define CCL_SELECTION_ALGO_DELIMETER ':' #define CCL_SELECTION_SIZE_DELIMETER '-' +std::string to_string(const ccl_selector_param& param); + template void ccl_selection_unpack_elem(size_t& size, algo_group_type& algo, @@ -74,7 +76,7 @@ void ccl_algorithm_selector_base::init() { try { if (!std::getline(block_stream, algo_name_str, CCL_SELECTION_ALGO_DELIMETER)) CCL_THROW( - "can't parse algorithm name from string: ", str_to_parse, ", block: ", block); + "can not parse algorithm name from string: ", str_to_parse, ", block: ", block); } catch (const std::istream::failure& e) { LOG_ERROR("exception happened: ", @@ -85,7 +87,8 @@ void ccl_algorithm_selector_base::init() { block_stream.eof(), "\nbadbit: ", block_stream.bad()); - CCL_THROW("can't parse algorithm name from string: ", str_to_parse, ", block: ", block); + CCL_THROW( + "can not parse algorithm name from string: ", str_to_parse, ", block: ", block); } LOG_TRACE("block ", block, ", algo_name_str ", algo_name_str); @@ -103,7 +106,7 @@ void ccl_algorithm_selector_base::init() { block_stream.str(block.substr(algo_name_str.length() + 1)); if (!std::getline(block_stream, size_str, CCL_SELECTION_SIZE_DELIMETER)) CCL_THROW( - "can't parse left size from string: ", str_to_parse, ", block: ", block); + "can not parse left size from string: ", str_to_parse, ", block: ", 
block); if (!size_str.compare(CCL_SELECTION_MAX_COLL_SIZE_STR)) left_size = CCL_SELECTION_MAX_COLL_SIZE; else @@ -111,14 +114,17 @@ void ccl_algorithm_selector_base::init() { } catch (const std::exception& e) { LOG_ERROR("exception happened during left size parsing: ", e.what()); - CCL_THROW("can't parse left size from string: ", str_to_parse, ", block: ", block); + CCL_THROW( + "can not parse left size from string: ", str_to_parse, ", block: ", block); } try { block_stream.str(block.substr(algo_name_str.length() + size_str.length() + 2)); if (!std::getline(block_stream, size_str, CCL_SELECTION_SIZE_DELIMETER)) - CCL_THROW( - "can't parse second size from string: ", str_to_parse, ", block: ", block); + CCL_THROW("can not parse second size from string: ", + str_to_parse, + ", block: ", + block); if (!size_str.compare(CCL_SELECTION_MAX_COLL_SIZE_STR)) right_size = CCL_SELECTION_MAX_COLL_SIZE; else @@ -126,7 +132,8 @@ void ccl_algorithm_selector_base::init() { } catch (const std::exception& e) { LOG_ERROR("exception happened during right size parsing: ", e.what()); - CCL_THROW("can't parse right size from string: ", str_to_parse, ", block: ", block); + CCL_THROW( + "can not parse right size from string: ", str_to_parse, ", block: ", block); } LOG_TRACE("algo ", algo_name_str, ", left ", left_size, ", right ", right_size); @@ -238,13 +245,15 @@ algo_group_type ccl_algorithm_selector_base::get( algo_group_type elem_algo; ccl_selection_border_type elem_border; + LOG_DEBUG("param: ", ::to_string(param)); + size_t count = ccl_algorithm_selector_helper::get_count(param); if (param.hint_algo.has_value()) { elem_algo = static_cast(param.hint_algo.value); if (!ccl_algorithm_selector_helper::can_use( elem_algo, param, main_table)) { - LOG_DEBUG("can't select hint algorithm: coll ", + LOG_DEBUG("can not select hint algorithm: coll ", ccl_coll_type_to_str(param.ctype), ", count ", count, @@ -269,16 +278,23 @@ algo_group_type ccl_algorithm_selector_base::get( if (lower_bound == 
main_table.end() || !ccl_algorithm_selector_helper::can_use(elem_algo, param, main_table)) { + CCL_THROW_IF_NOT(ccl::global_data::env().enable_algo_fallback, + "can not select algo from main table and fallback is disabled", + ", coll ", + ccl_coll_type_to_str(param.ctype), + ", count ", + count); + lower_bound = fallback_table.lower_bound(size); ccl_selection_unpack_elem(elem_size, elem_algo, elem_border, lower_bound, fallback_table); CCL_THROW_IF_NOT(lower_bound != fallback_table.end(), - "can't select algorithm: coll ", + "can not select algorithm: coll ", ccl_coll_type_to_str(param.ctype), ", count ", count); CCL_THROW_IF_NOT(ccl_algorithm_selector_helper::can_use( elem_algo, param, fallback_table), - "can't select algorithm in fallback_table: coll ", + "can not select algorithm in fallback_table: coll ", ccl_coll_type_to_str(param.ctype)); } diff --git a/src/coll/selection/selector_reduce.cpp b/src/coll/selection/selector_reduce.cpp index 4c13ea035..5e91b1e7c 100644 --- a/src/coll/selection/selector_reduce.cpp +++ b/src/coll/selection/selector_reduce.cpp @@ -22,20 +22,20 @@ std::map std::make_pair(ccl_coll_reduce_rabenseifner, "rabenseifner"), std::make_pair(ccl_coll_reduce_tree, "tree"), std::make_pair(ccl_coll_reduce_double_tree, "double_tree"), - std::make_pair(ccl_coll_reduce_topo_ring, "topo_ring") + std::make_pair(ccl_coll_reduce_topo, "topo") }; ccl_algorithm_selector::ccl_algorithm_selector() { -#if defined(CCL_ENABLE_SYCL) && defined(MULTI_GPU_SUPPORT) - insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_reduce_topo_ring); -#else // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT +#if defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE) + insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_reduce_topo); +#else // CCL_ENABLE_SYCL && CCL_ENABLE_ZE if (ccl::global_data::env().atl_transport == ccl_atl_ofi) { insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_reduce_tree); } else if (ccl::global_data::env().atl_transport == ccl_atl_mpi) { 
insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_reduce_direct); } -#endif // CCL_ENABLE_SYCL && MULTI_GPU_SUPPORT +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE insert(fallback_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_reduce_tree); } @@ -56,7 +56,7 @@ bool ccl_algorithm_selector_helper::can_use( else if (algo == ccl_coll_reduce_direct && (ccl::global_data::env().atl_transport == ccl_atl_ofi)) can_use = false; - else if (algo == ccl_coll_reduce_topo_ring && !ccl_can_use_topo_ring_algo(param)) + else if (algo == ccl_coll_reduce_topo && !ccl_can_use_topo_algo(param)) can_use = false; return can_use; diff --git a/src/coll/selection/selector_reduce_scatter.cpp b/src/coll/selection/selector_reduce_scatter.cpp index 3d8f67e01..e68b369dc 100644 --- a/src/coll/selection/selector_reduce_scatter.cpp +++ b/src/coll/selection/selector_reduce_scatter.cpp @@ -20,13 +20,18 @@ std::map ccl_algorithm_selector_helper::algo_names = { std::make_pair(ccl_coll_reduce_scatter_direct, "direct"), std::make_pair(ccl_coll_reduce_scatter_ring, "ring"), + std::make_pair(ccl_coll_reduce_scatter_topo, "topo"), }; ccl_algorithm_selector::ccl_algorithm_selector() { +#if defined(CCL_ENABLE_SYCL) && defined(CCL_ENABLE_ZE) + insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_reduce_scatter_topo); +#else // CCL_ENABLE_SYCL && CCL_ENABLE_ZE if (ccl::global_data::env().atl_transport == ccl_atl_ofi) insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_reduce_scatter_ring); else if (ccl::global_data::env().atl_transport == ccl_atl_mpi) insert(main_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_reduce_scatter_direct); +#endif // CCL_ENABLE_SYCL && CCL_ENABLE_ZE insert(fallback_table, 0, CCL_SELECTION_MAX_COLL_SIZE, ccl_coll_reduce_scatter_ring); } @@ -38,8 +43,11 @@ bool ccl_algorithm_selector_helper::can_use( const ccl_selection_table_t& table) { bool can_use = true; - if (algo == ccl_coll_reduce_scatter_direct && - (ccl::global_data::env().atl_transport == ccl_atl_ofi)) + 
if (algo == ccl_coll_reduce_scatter_topo && !ccl_can_use_topo_algo(param)) { + can_use = false; + } + else if (algo == ccl_coll_reduce_scatter_direct && + (ccl::global_data::env().atl_transport == ccl_atl_ofi)) can_use = false; return can_use; diff --git a/src/common/comm/atl_tag.cpp b/src/common/comm/atl_tag.cpp index 2bca7ba61..a60dd25e6 100644 --- a/src/common/comm/atl_tag.cpp +++ b/src/common/comm/atl_tag.cpp @@ -16,12 +16,12 @@ #include "common/comm/atl_tag.hpp" #include "exec/exec.hpp" -void ccl_atl_tag::print() { - LOG_INFO("atl-tag:"); - LOG_INFO(" bits: ", tag_bits); - LOG_INFO(" max: ", max_tag); - LOG_INFO(" mask: ", max_tag_mask); - LOG_INFO(" pof2: ", ccl_pof2(max_tag)); +std::string ccl_atl_tag::to_string() const { + std::stringstream ss; + ss << "{ " + << "bits: " << tag_bits << ", max: " << max_tag << ", mask: " << max_tag_mask + << ", pof2: " << ccl_pof2(max_tag) << " }"; + return ss.str(); } uint64_t ccl_atl_tag::create(int rank, diff --git a/src/common/comm/atl_tag.hpp b/src/common/comm/atl_tag.hpp index 4c9a46cfc..ac5e20b8a 100644 --- a/src/common/comm/atl_tag.hpp +++ b/src/common/comm/atl_tag.hpp @@ -41,7 +41,7 @@ class ccl_atl_tag { ~ccl_atl_tag() = default; - void print(); + std::string to_string() const; /** * Generates the tag to be used by ATL communication operations diff --git a/src/common/comm/comm.cpp b/src/common/comm/comm.cpp index fe6e8062e..98f6832a1 100644 --- a/src/common/comm/comm.cpp +++ b/src/common/comm/comm.cpp @@ -13,138 +13,267 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ +#include "atl/atl_base_comm.hpp" #include "atl/util/pm/pmi_resizable_rt/pmi_resizable/kvs/users_kvs.h" #include "exec/exec.hpp" +#include "coll/coll.hpp" +#include "coll/coll_common_attributes.hpp" +#include "coll/ccl_allgather_op_attr.hpp" #include "common/comm/comm.hpp" -#include "common/comm/host_communicator/host_communicator.hpp" +#include "common/comm/comm_impl.hpp" #include "common/global/global.hpp" +#include "common/event/impls/host_event.hpp" +#include "common/request/request.hpp" #include "sched/sched.hpp" #include "oneapi/ccl/types.hpp" #include "oneapi/ccl/kvs.hpp" +#include "oneapi/ccl/comm_split_attr_ids.hpp" +#include "oneapi/ccl/comm_split_attr_ids_traits.hpp" +#include "oneapi/ccl/comm_split_attr.hpp" +#include "util/pm/pmi_resizable_rt/pmi_resizable/kvs/ikvs_wrapper.h" -void ccl_comm::allocate_resources() { - if (ccl::global_data::env().enable_unordered_coll) { - unordered_coll_manager = - std::unique_ptr(new ccl_unordered_coll_manager(*this)); +ccl_comm_internal::ccl_comm_internal(int rank, int size, std::shared_ptr atl) + : ccl_comm_internal(rank, size, atl->get_rank2rank_map(), atl) {} + +ccl_comm_internal::ccl_comm_internal(int rank, + int size, + ccl_rank2rank_map&& rank_map, + std::shared_ptr atl) + : atl(atl), + m_local2global_map(std::move(rank_map)), + m_dtree(size, rank) { + reset(rank, size); +} + +ccl_comm_internal::ccl_comm_internal(const std::vector& local_ranks, + int comm_size, + std::shared_ptr kvs_instance) + : m_local2global_map(), + m_dtree(local_ranks.size(), comm_size) { + std::shared_ptr kvs_wrapper(new users_kvs(kvs_instance)); + + atl = atl_comm_manager::create_comm(comm_size, local_ranks, kvs_wrapper); + + reset(atl->get_rank(), atl->get_size()); +} + +//TODO: will fix it after OFI refactoring +int ccl_comm::get_global_rank(int rank, bool only_global) const { + // TODO: move map to ccl_comm? 
+ const auto& local2global_map = comm_impl->get_local2global_map(); + + if (local2global_map.empty() || !only_global) { + // global comm and its copies do not have entries in the map + return rank; } - auto& env_object = ccl::global_data::env(); + CCL_THROW_IF_NOT((int)local2global_map.size() > rank, + "no rank ", + rank, + " was found in comm ", + this, + ", id ", + id()); + int global_rank = local2global_map[rank]; + LOG_DEBUG("comm ", this, ", id ", id(), ", map rank ", rank, " to global ", global_rank); + return global_rank; +} - allreduce_2d_builder = std::unique_ptr(new ccl_allreduce_2d_builder( - (env_object.allreduce_2d_base_size != CCL_ENV_SIZET_NOT_SPECIFIED) - ? env_object.allreduce_2d_base_size - : ccl::global_data::get().executor->get_local_proc_count(), - env_object.allreduce_2d_switch_dims, - this)); +int ccl_comm::get_rank_from_global(int global_rank) const { + const auto& local2global_map = comm_impl->get_local2global_map(); + + if (local2global_map.empty()) { + // global comm and its copies do not have entries in the map + return global_rank; + } + + int rank = ccl_comm::invalid_rank; + + for (size_t i = 0; i < local2global_map.size(); ++i) { + if (local2global_map[i] == global_rank) { + rank = static_cast(i); + break; + } + } + + CCL_THROW_IF_NOT(rank != ccl_comm::invalid_rank, "can't find rank"); + + return rank; +} - env_object.print(m_rank); +using ccl::preview::create_comm_split_attr; + +ccl_comm::ccl_comm() + : device(ccl::device_index_type(ccl::unused_index_value, + ccl::unused_index_value, + ccl::unused_index_value)), + comm_attr(create_comm_split_attr()) {} + +ccl_comm::ccl_comm(int size, ccl::shared_ptr_class kvs) + : device(ccl::device_index_type(ccl::unused_index_value, + ccl::unused_index_value, + ccl::unused_index_value)), + comm_attr(create_comm_split_attr()), + comm_rank(0), + comm_size(size), + next_sched_id_internal(ccl_comm_internal::max_sched_count / 2), + next_sched_id_external(0) { + if (size <= 0) { + throw 
ccl::exception("Incorrect size value when creating a host communicator"); + } +} + +ccl_comm::ccl_comm(int size, int rank, ccl::shared_ptr_class kvs) + : ccl_comm(atl_comm_manager::create_comm(size, { rank }, kvs)) {} + +ccl_comm::ccl_comm(ccl::unified_device_type&& d, + ccl::unified_context_type&& c, + std::shared_ptr atl) + : device(std::move(d)), + context(std::move(c)), + comm_attr(create_comm_split_attr()), + comm_rank(atl->get_rank()), + comm_size(atl->get_size()), + comm_id(std::unique_ptr( + new ccl_comm_id_storage::comm_id(ccl::global_data::get().comm_ids->acquire()))), + next_sched_id_internal(ccl_comm_internal::max_sched_count / 2), + next_sched_id_external(0) { + int rank = atl->get_rank(); + int size = atl->get_size(); + + if (rank > size || size <= 0) { + throw ccl::exception("incorrect rank or size when creating \ + a host communicator: rank: " + + std::to_string(rank) + ", size: " + std::to_string(size)); + } + + LOG_DEBUG("ctor"); + + comm_impl = std::unique_ptr(new ccl_comm_internal(rank, size, atl)); + + allocate_resources(); + create_sub_comms(atl); } +ccl_comm::ccl_comm(std::shared_ptr atl) + : ccl_comm(ccl::device_index_type(ccl::unused_index_value, + ccl::unused_index_value, + ccl::unused_index_value), + {}, + atl) {} + ccl_comm::ccl_comm(int rank, int size, ccl_comm_id_storage::comm_id&& id, - std::shared_ptr atl, + std::shared_ptr atl, bool share_resources, - ccl::host_communicator* host_comm) - : ccl_comm(rank, - size, - std::move(id), - ccl_rank2rank_map{}, - atl, - share_resources, - host_comm) {} + bool is_sub_communicator) + : comm_impl(std::make_shared(rank, size, atl->get_rank2rank_map(), atl)), + device(ccl::device_index_type(ccl::unused_index_value, + ccl::unused_index_value, + ccl::unused_index_value)), + comm_attr(create_comm_split_attr()), + comm_rank(rank), + comm_size(size), + comm_id(std::unique_ptr( + new ccl_comm_id_storage::comm_id(std::move(id)))), + next_sched_id_internal(ccl_comm_internal::max_sched_count / 2), + 
next_sched_id_external(0) { + if (!share_resources) { + allocate_resources(); + } + + if (!is_sub_communicator) { + create_sub_comms(comm_impl.get()->atl); + } +} ccl_comm::ccl_comm(int rank, int size, ccl_comm_id_storage::comm_id&& id, ccl_rank2rank_map&& rank_map, - std::shared_ptr atl, + std::shared_ptr atl, bool share_resources, - ccl::host_communicator* host_comm) - : atl(atl), - m_id(std::move(id)), - m_local2global_map(std::move(rank_map)), - m_dtree(size, rank), - thread_number(1), - on_process_ranks_number(1), - host_comm(host_comm) { - reset(rank, size); - + bool is_sub_communicator) + : comm_impl(std::make_shared(rank, size, std::move(rank_map), atl)), + device(ccl::device_index_type(ccl::unused_index_value, + ccl::unused_index_value, + ccl::unused_index_value)), + comm_attr(create_comm_split_attr()), + comm_rank(rank), + comm_size(size), + comm_id(std::unique_ptr( + new ccl_comm_id_storage::comm_id(std::move(id)))), + next_sched_id_internal(ccl_comm_internal::max_sched_count / 2), + next_sched_id_external(0) { if (!share_resources) { allocate_resources(); } -} -//TODO non-implemented -//TODO rude simulation of multi-thread barrier -static std::atomic thread_counter{}; -static std::atomic thread_ranks_counter{}; -void ccl_comm::ccl_comm_reset_thread_barrier() { - // recharge counters again - thread_counter.store(0); - thread_ranks_counter.store(0); + if (!is_sub_communicator) { + create_sub_comms(get_atl_comm()); + } } -ccl_comm::ccl_comm(const std::vector& local_ranks, - int comm_size, - std::shared_ptr kvs_instance, - ccl_comm_id_storage::comm_id&& id, - bool share_resources, - ccl::host_communicator* host_comm) - : m_id(std::move(id)), - m_local2global_map(), - m_dtree(local_ranks.size(), comm_size), - host_comm(host_comm) { - std::shared_ptr kvs_wrapper(new users_kvs(kvs_instance)); - - atl = std::shared_ptr(new atl_wrapper(comm_size, local_ranks, kvs_wrapper)); +ccl_comm::ccl_comm(const ccl_comm& src, ccl_comm_id_storage::comm_id&& id) + : 
comm_impl(src.comm_impl), + device(ccl::device_index_type(ccl::unused_index_value, + ccl::unused_index_value, + ccl::unused_index_value)), + r2r_comm(src.r2r_comm), + node_comm(src.node_comm), + even_comm(src.even_comm), + pair_comm(src.pair_comm), + comm_attr(create_comm_split_attr()), + comm_rank(src.rank()), + comm_size(src.size()), + comm_id(std::unique_ptr( + new ccl_comm_id_storage::comm_id(std::move(id)))), + next_sched_id_internal(ccl_comm_internal::max_sched_count / 2), + next_sched_id_external(0) {} - thread_number = atl->get_threads_per_process(); - on_process_ranks_number = atl->get_ranks_per_process(); - - reset(atl->get_rank(), atl->get_size()); +ccl::device_index_type ccl_comm::get_device_path() const { + return ccl::device_index_type{ ccl::unused_index_value, + ccl::unused_index_value, + ccl::unused_index_value }; +} - if (!share_resources) { - allocate_resources(); - } +ccl::communicator_interface::device_t ccl_comm::get_device() const { + CCL_THROW(std::string(__FUNCTION__) + " is not applicable for " + traits::name()); + static ccl::communicator_interface::device_t empty; + return empty; } -ccl_comm* ccl_comm::create_with_colors(const std::vector& colors, - ccl_comm_id_storage* comm_ids, - const ccl_comm* parent_comm, - bool share_resources) { - ccl_rank2rank_map rank_map; - int new_comm_size = 0; - int new_comm_rank = 0; - int color = colors[parent_comm->rank()]; - - for (int i = 0; i < parent_comm->size(); ++i) { - if (colors[i] == color) { - LOG_DEBUG("map local rank ", new_comm_size, " to global ", i); - rank_map.emplace_back(i); - ++new_comm_size; - if (i < parent_comm->rank()) { - ++new_comm_rank; - } - } - } +ccl::communicator_interface::context_t ccl_comm::get_context() const { + CCL_THROW(std::string(__FUNCTION__) + " is not applicable for " + traits::name()); + static ccl::communicator_interface::context_t empty; + return empty; +} - if (new_comm_size == 0) { - throw ccl::exception(std::string("no colors matched to ") + 
std::to_string(color) + - " seems to be exchange issue"); - } +void ccl_comm::create_sub_comms(std::shared_ptr atl) { + ccl::global_data& data = ccl::global_data::get(); - if (new_comm_size == parent_comm->size()) { - // exact copy of the global communicator, use empty map - rank_map.clear(); - } + r2r_comm = std::shared_ptr( + this->create_with_color(atl->get_r2r_color(), data.comm_ids.get(), true)); + node_comm = std::shared_ptr( + this->create_with_color(atl->get_host_color(), data.comm_ids.get(), true)); + even_comm = std::shared_ptr(this->create_with_color( + atl->get_host_color() + atl->get_rank() % 2, data.comm_ids.get(), true)); + pair_comm = std::shared_ptr(this->create_with_color( + atl->get_host_color() + atl->get_rank() / 2, data.comm_ids.get(), true)); +} - ccl_comm* comm = new ccl_comm(new_comm_rank, - new_comm_size, +ccl_comm* ccl_comm::create_with_color(int color, + ccl_comm_id_storage* comm_ids, + bool share_resources) const { + std::shared_ptr atl_comm = get_atl_comm()->comm_split(color); + ccl_comm* comm = new ccl_comm(atl_comm->get_rank(), + atl_comm->get_size(), comm_ids->acquire(), - std::move(rank_map), - parent_comm->atl, - share_resources); + atl_comm->get_rank2rank_map(), + atl_comm, + share_resources, + true); LOG_DEBUG("new comm: color ", color, @@ -158,32 +287,306 @@ ccl_comm* ccl_comm::create_with_colors(const std::vector& colors, return comm; } -std::shared_ptr ccl_comm::clone_with_new_id(ccl_comm_id_storage::comm_id&& id) { - ccl_rank2rank_map rank_map{ m_local2global_map }; - return std::make_shared(m_rank, - m_size, - std::move(id), - std::move(rank_map), - atl, - true /*share_resources*/, - get_host_comm()); -} - -int ccl_comm::get_global_rank(int rank) const { - if (m_local2global_map.empty()) { - // global comm and its copies do not have entries in the map - return rank; +ccl::communicator_interface_ptr ccl_comm::split(const ccl::comm_split_attr& attr) { + if (!attr.is_valid()) { + CCL_THROW(std::string(__FUNCTION__) + + " - 
'Color' split attribute for host communicator is not set"); } - CCL_THROW_IF_NOT((int)m_local2global_map.size() > rank, - "no rank ", - rank, - " was found in comm ", - this, - ", id ", - m_id.value()); - int global_rank = m_local2global_map[rank]; - LOG_DEBUG( - "comm , ", this, " id ", m_id.value(), ", map rank ", rank, " to global ", global_rank); - return global_rank; + ccl::global_data& data = ccl::global_data::get(); + auto new_comm = this->create_with_color( + attr.get(), data.comm_ids.get(), true); + + comm_attr = attr; + + return std::shared_ptr(new_comm); +} + +ccl::event ccl_comm::barrier(const ccl::stream::impl_value_t& stream, + const ccl::barrier_attr& attr, + const ccl::vector_class& deps) { + return barrier_impl(stream, attr, deps); +} + +ccl::event ccl_comm::barrier_impl(const ccl::stream::impl_value_t& stream, + const ccl::barrier_attr& attr, + const ccl::vector_class& deps) { + ccl_barrier_impl(this, stream.get(), deps); + return std::unique_ptr(new ccl::host_event_impl(nullptr)); +} + +/* allgatherv */ +ccl::event ccl_comm::allgatherv_impl(const void* send_buf, + size_t send_count, + void* recv_buf, + const ccl::vector_class& recv_counts, + ccl::datatype dtype, + const ccl::stream::impl_value_t& stream, + const ccl::allgatherv_attr& attr, + const ccl::vector_class& deps) { + ccl_request* req = ccl_allgatherv_impl(send_buf, + send_count, + recv_buf, + recv_counts.data(), + dtype, + attr, + this, + get_stream_ptr(stream), + deps); + + return std::unique_ptr(new ccl::host_event_impl(req)); +} + +ccl::event ccl_comm::allgatherv_impl(const void* send_buf, + size_t send_count, + const ccl::vector_class& recv_bufs, + const ccl::vector_class& recv_counts, + ccl::datatype dtype, + const ccl::stream::impl_value_t& stream, + const ccl::allgatherv_attr& attr, + const ccl::vector_class& deps) { + ccl_coll_attr internal_attr(attr); + internal_attr.is_vector_buf = 1; + + ccl_request* req = ccl_allgatherv_impl(reinterpret_cast(send_buf), + send_count, + 
(void*)(recv_bufs.data()), + recv_counts.data(), + dtype, + internal_attr, + this, + get_stream_ptr(stream), + deps); + + return std::unique_ptr(new ccl::host_event_impl(req)); +} + +/* allreduce */ +ccl::event ccl_comm::allreduce_impl(const void* send_buf, + void* recv_buf, + size_t count, + ccl::datatype dtype, + ccl::reduction reduction, + const ccl::stream::impl_value_t& stream, + const ccl::allreduce_attr& attr, + const ccl::vector_class& deps) { + ccl_request* req = ccl_allreduce_impl( + send_buf, recv_buf, count, dtype, reduction, attr, this, get_stream_ptr(stream), deps); + + return std::unique_ptr(new ccl::host_event_impl(req)); +} + +/* alltoall */ +ccl::event ccl_comm::alltoall_impl(const void* send_buf, + void* recv_buf, + size_t count, + ccl::datatype dtype, + const ccl::stream::impl_value_t& stream, + const ccl::alltoall_attr& attr, + const ccl::vector_class& deps) { + ccl_request* req = ccl_alltoall_impl( + send_buf, recv_buf, count, dtype, attr, this, get_stream_ptr(stream), deps); + + return std::unique_ptr(new ccl::host_event_impl(req)); +} + +ccl::event ccl_comm::alltoall_impl(const ccl::vector_class& send_buf, + const ccl::vector_class& recv_buf, + size_t count, + ccl::datatype dtype, + const ccl::stream::impl_value_t& stream, + const ccl::alltoall_attr& attr, + const ccl::vector_class& deps) { + // TODO not implemented + CCL_THROW(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); + return {}; +} + +/* alltoallv */ +ccl::event ccl_comm::alltoallv_impl(const void* send_buf, + const ccl::vector_class& send_counts, + void* recv_buf, + const ccl::vector_class& recv_counts, + ccl::datatype dtype, + const ccl::stream::impl_value_t& stream, + const ccl::alltoallv_attr& attr, + const ccl::vector_class& deps) { + ccl_request* req = ccl_alltoallv_impl(send_buf, + send_counts.data(), + recv_buf, + recv_counts.data(), + dtype, + attr, + this, + get_stream_ptr(stream), + deps); + + return std::unique_ptr(new ccl::host_event_impl(req)); +} + 
+ccl::event ccl_comm::alltoallv_impl(const ccl::vector_class& send_buf, + const ccl::vector_class& send_counts, + ccl::vector_class recv_buf, + const ccl::vector_class& recv_counts, + ccl::datatype dtype, + const ccl::stream::impl_value_t& stream, + const ccl::alltoallv_attr& attr, + const ccl::vector_class& dep) { + // TODO not implemented + CCL_THROW(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); + return {}; +} + +/* bcast */ +ccl::event ccl_comm::broadcast_impl(void* buf, + size_t count, + ccl::datatype dtype, + int root, + const ccl::stream::impl_value_t& stream, + const ccl::broadcast_attr& attr, + const ccl::vector_class& deps) { + ccl_request* req = + ccl_broadcast_impl(buf, count, dtype, root, attr, this, get_stream_ptr(stream), deps); + + return std::unique_ptr(new ccl::host_event_impl(req)); +} + +/* reduce */ +ccl::event ccl_comm::reduce_impl(const void* send_buf, + void* recv_buf, + size_t count, + ccl::datatype dtype, + ccl::reduction reduction, + int root, + const ccl::stream::impl_value_t& stream, + const ccl::reduce_attr& attr, + const ccl::vector_class& deps) { + ccl_request* req = ccl_reduce_impl(send_buf, + recv_buf, + count, + dtype, + reduction, + root, + attr, + this, + get_stream_ptr(stream), + deps); + + return std::unique_ptr(new ccl::host_event_impl(req)); +} + +/* reduce_scatter */ +ccl::event ccl_comm::reduce_scatter_impl(const void* send_buf, + void* recv_buf, + size_t recv_count, + ccl::datatype dtype, + ccl::reduction reduction, + const ccl::stream::impl_value_t& stream, + const ccl::reduce_scatter_attr& attr, + const ccl::vector_class& deps) { + ccl_request* req = ccl_reduce_scatter_impl( + send_buf, recv_buf, recv_count, dtype, reduction, attr, this, get_stream_ptr(stream), deps); + + return std::unique_ptr(new ccl::host_event_impl(req)); +} + +/* sparse_allreduce */ +ccl::event ccl_comm::sparse_allreduce_impl(const void* send_ind_buf, + size_t send_ind_count, + const void* send_val_buf, + size_t send_val_count, + 
void* recv_ind_buf, + size_t recv_ind_count, + void* recv_val_buf, + size_t recv_val_count, + ccl::datatype index_dtype, + ccl::datatype value_dtype, + ccl::reduction reduction, + const ccl::stream::impl_value_t& stream, + const ccl::sparse_allreduce_attr& attr, + const ccl::vector_class& deps) { + ccl_request* req = ccl_sparse_allreduce_impl(send_ind_buf, + send_ind_count, + send_val_buf, + send_val_count, + recv_ind_buf, + recv_ind_count, + recv_val_buf, + recv_val_count, + index_dtype, + value_dtype, + reduction, + attr, + this, + get_stream_ptr(stream), + deps); + + return std::unique_ptr(new ccl::host_event_impl(req)); +} + +std::shared_ptr ccl_comm::get_atl_comm() const { + return comm_impl->atl; +} + +std::shared_ptr ccl_comm::get_r2r_comm() { + return r2r_comm; +} + +std::shared_ptr ccl_comm::get_node_comm() { + return node_comm; +} + +std::shared_ptr ccl_comm::get_pair_comm() { + return pair_comm; +} + +std::shared_ptr ccl_comm::get_even_comm() { + return even_comm; +} + +std::string ccl_comm::to_string() const { + std::stringstream ss; + ss << "{ rank: " << rank() << ", size: " << size() << ", id: " << id() << " }"; + return ss.str(); +} + +std::string ccl_comm::to_string_ext() const { + std::stringstream ss; + ss << "{\n"; + ss << " " << to_string() << "\n"; + ss << " r2r_comm: " << (r2r_comm ? r2r_comm->to_string() : "{}") << "\n"; + ss << " node_comm: " << (node_comm ? node_comm->to_string() : "{}") << "\n"; + ss << " even_comm: " << (even_comm ? even_comm->to_string() : "{}") << "\n"; + ss << " pair_comm: " << (pair_comm ? 
pair_comm->to_string() : "{}") << "\n"; + ss << "}"; + + return ss.str(); +} + +// NOTE: allocate_resources must be done on ccl_comm level, if it's called on ccl_comm_internal level +// the ccl_comm object that we need won't be fully constructed +void ccl_comm::allocate_resources() { + if (ccl::global_data::env().enable_unordered_coll) { + comm_impl->unordered_coll_manager.reset(new ccl_unordered_coll_manager(*this)); + } + + auto& env_object = ccl::global_data::env(); + + comm_impl->allreduce_2d_builder.reset(new ccl_allreduce_2d_builder( + (env_object.allreduce_2d_base_size != CCL_ENV_SIZET_NOT_SPECIFIED) + ? env_object.allreduce_2d_base_size + : ccl::global_data::get().executor->get_local_proc_count(), + env_object.allreduce_2d_switch_dims, + this)); + + env_object.print(rank()); } + +std::shared_ptr ccl_comm::clone_with_new_id(ccl_comm_id_storage::comm_id&& id) { + return std::shared_ptr(new ccl_comm(*this, std::move(id))); +} + +COMM_INTERFACE_COLL_INSTANTIATION(ccl_comm); +#ifdef CCL_ENABLE_SYCL +SYCL_COMM_INTERFACE_COLL_INSTANTIATION(ccl_comm); +#endif // CCL_ENABLE_SYCL diff --git a/src/common/comm/comm.hpp b/src/common/comm/comm.hpp index 77505c705..1356ec53e 100644 --- a/src/common/comm/comm.hpp +++ b/src/common/comm/comm.hpp @@ -17,53 +17,218 @@ #include #include - -#include "atl/atl_wrapper.h" +#include "atl/atl_base_comm.hpp" #include "coll/algorithms/allreduce/allreduce_2d.hpp" +#include "common/comm/communicator_traits.hpp" +#include "common/comm/comm_interface.hpp" #include "common/comm/comm_id_storage.hpp" #include "common/comm/atl_tag.hpp" #include "common/log/log.hpp" +#include "common/stream/stream.hpp" #include "common/utils/tree.hpp" #include "common/utils/utils.hpp" +#include "oneapi/ccl/types.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/comm_split_attr_ids.hpp" +#include "oneapi/ccl/comm_split_attr_ids_traits.hpp" +#include "oneapi/ccl/comm_split_attr.hpp" +#include "oneapi/ccl/types.hpp" +#include 
"oneapi/ccl/type_traits.hpp" +#include "oneapi/ccl/types_policy.hpp" +#include "oneapi/ccl/event.hpp" +#include "oneapi/ccl/coll_attr_ids.hpp" +#include "oneapi/ccl/coll_attr_ids_traits.hpp" +#include "oneapi/ccl/coll_attr.hpp" +#include "types_generator_defines.hpp" #include "unordered_coll/unordered_coll.hpp" // index = local_rank, value = global_rank using ccl_rank2rank_map = std::vector; +class ikvs_wrapper; + +inline ccl_stream* get_stream_ptr(const ccl::stream::impl_value_t& stream) { + if (stream.get() && stream->is_sycl_device_stream()) + return stream.get(); + else + return nullptr; +} + +using ccl_rank2rank_map = std::vector; + +class ccl_comm; namespace ccl { -class host_communicator; namespace v1 { class kvs_interface; } } // namespace ccl -class alignas(CACHELINE_SIZE) ccl_comm { +// The main purpose of the internal part is to hold shareable parts of ccl_comm which don't need to +// be copied/reset on ccl_comm's copy. +class alignas(CACHELINE_SIZE) ccl_comm_internal { public: - static constexpr int invalid_rank = -1; + static void ccl_comm_reset_thread_barrier(); + ccl_comm_internal() = delete; + ccl_comm_internal(const ccl_comm_internal& other) = delete; + ccl_comm_internal& operator=(const ccl_comm_internal& other) = delete; + + ccl_comm_internal(int rank, int size, std::shared_ptr atl); - ccl::host_communicator* get_host_comm() { - return host_comm; + ccl_comm_internal(int rank, + int size, + ccl_rank2rank_map&& ranks, + std::shared_ptr atl); + + //TODO non-implemented + //1) cluster_devices_count (devices 1000) -> (processes 10) + //2) blocking until all thread -> calls ccl_comm + //3) return 'thread_count' + + // ccl_comm( {0,1,2,3...}, 1000, kvs ) + // from 20 processes from ranks 0,1,2,3. 
Each rank contains 10 threads + // communicator: size in {20} and ranks in {0..19} + // communicator: return threads count in process {10} + // communicator: return devices counts per thread in process + ccl_comm_internal(const std::vector& local_ranks, + int comm_size, + std::shared_ptr kvs_instance); + + ~ccl_comm_internal() = default; + + int rank() const noexcept { + return m_rank; } - static void ccl_comm_reset_thread_barrier(); - ccl_comm() = delete; - ccl_comm(const ccl_comm& other) = delete; - ccl_comm& operator=(const ccl_comm& other) = delete; + int size() const noexcept { + return m_size; + } + + int pof2() const noexcept { + return m_pof2; + } + + const ccl_double_tree& dtree() const { + return m_dtree; + } + + void reset(int rank, int size) { + m_rank = rank; + m_size = size; + m_pof2 = ccl_pof2(m_size); + } + + const ccl_rank2rank_map& get_local2global_map() { + return m_local2global_map; + } + + /** + * Maximum available number of active communicators + */ + static constexpr ccl_sched_id_t max_comm_count = std::numeric_limits::max(); + /** + * Maximum value of schedule id in scope of the current communicator + */ + static constexpr ccl_sched_id_t max_sched_count = std::numeric_limits::max(); + std::shared_ptr atl; + std::unique_ptr unordered_coll_manager; + std::unique_ptr allreduce_2d_builder; + +private: + int m_rank; + int m_size; + int m_pof2; + + ccl_rank2rank_map m_local2global_map{}; + ccl_double_tree m_dtree; +}; + +class alignas(CACHELINE_SIZE) ccl_comm : public ccl::communicator_interface { +public: + using traits = ccl::host_communicator_traits; + + // traits + bool is_host() const noexcept override { + return traits::is_host(); + } + + bool is_cpu() const noexcept override { + return traits::is_cpu(); + } + + bool is_gpu() const noexcept override { + return traits::is_gpu(); + } + + bool is_accelerator() const noexcept override { + return traits::is_accelerator(); + } + + bool is_ready() const override { + return true; + } + + 
ccl::device_index_type get_device_path() const override; + ccl::communicator_interface::device_t get_device() const override; + ccl::communicator_interface::context_t get_context() const override; + + const ccl::comm_split_attr& get_comm_split_attr() const override { + return comm_attr; + } + + ccl::group_split_type get_topology_type() const override { + CCL_THROW(std::string(__FUNCTION__) + " is not applicable for " + traits::name()); + return ccl::group_split_type::undetermined; + } + + ccl::device_topology_type get_topology_class() const override { + CCL_THROW(std::string(__FUNCTION__) + " is not applicable for " + traits::name()); + return ccl::device_topology_type::undetermined; + } + + ccl::communicator_interface_ptr split(const ccl::comm_split_attr& attr) override; + + // collectives operation declarations + ccl::event barrier(const ccl::stream::impl_value_t& op_stream, + const ccl::barrier_attr& attr, + const ccl::vector_class& deps = {}) override; + ccl::event barrier_impl(const ccl::stream::impl_value_t& op_stream, + const ccl::barrier_attr& attr, + const ccl::vector_class& deps = {}); + + COMM_INTERFACE_COLL_METHODS(DEFINITION); +#ifdef CCL_ENABLE_SYCL + SYCL_COMM_INTERFACE_COLL_METHODS(DEFINITION); +#endif // CCL_ENABLE_SYCL + + COMM_IMPL_DECLARATION; + COMM_IMPL_CLASS_DECLARATION + COMM_IMPL_SPARSE_DECLARATION; + COMM_IMPL_SPARSE_CLASS_DECLARATION + + ccl_comm(); + ccl_comm(int size, ccl::shared_ptr_class kvs); + ccl_comm(int size, int rank, ccl::shared_ptr_class kvs); + ccl_comm(ccl::unified_device_type&& device, + ccl::unified_context_type&& context, + std::shared_ptr atl); + ccl_comm(std::shared_ptr atl); + +public: ccl_comm(int rank, int size, ccl_comm_id_storage::comm_id&& id, - std::shared_ptr atl, + std::shared_ptr atl, bool share_resources = false, - ccl::host_communicator* host_comm = nullptr); + bool is_sub_communicator = false); ccl_comm(int rank, int size, ccl_comm_id_storage::comm_id&& id, ccl_rank2rank_map&& ranks, - std::shared_ptr atl, 
+ std::shared_ptr atl, bool share_resources = false, - ccl::host_communicator* host_comm = nullptr); + bool is_sub_communicator = false); //TODO non-implemented //1) cluster_devices_count (devices 1000) -> (processes 10) @@ -80,51 +245,83 @@ class alignas(CACHELINE_SIZE) ccl_comm { std::shared_ptr kvs_instance, ccl_comm_id_storage::comm_id&& id, bool share_resources = false, - ccl::host_communicator* host_comm = nullptr); + bool is_sub_communicator = false); + +private: + // This is copy-constructor alike which basically means to copy-construct from src + // but replace m_id with id's value. + // We can't have a simple copy constructor here due to comm_id type limitation + ccl_comm(const ccl_comm& src, ccl_comm_id_storage::comm_id&& id); +public: + ccl_comm(ccl_comm& src) = delete; + ccl_comm(ccl_comm&& src) = default; + ccl_comm& operator=(ccl_comm& src) = delete; + ccl_comm& operator=(ccl_comm&& src) = default; ~ccl_comm() = default; + std::shared_ptr get_atl_comm() const; + std::shared_ptr get_r2r_comm(); + std::shared_ptr get_node_comm(); + std::shared_ptr get_even_comm(); + std::shared_ptr get_pair_comm(); - /* version with user-provided colors, allows to skip allgatherv */ - static ccl_comm* create_with_colors(const std::vector& colors, - ccl_comm_id_storage* comm_ids, - const ccl_comm* parent_comm, - bool share_resources = false); + // troubleshooting + std::string to_string() const; + std::string to_string_ext() const; - std::shared_ptr clone_with_new_id(ccl_comm_id_storage::comm_id&& id); + static constexpr int invalid_rank = -1; - int rank() const noexcept { - return m_rank; + /** + * Returns the number of @c rank in the global communicator + * @param rank a rank which is part of the current communicator + * @return number of @c rank in the global communicator + */ + int get_global_rank(int rank, bool only_global = false) const; + int get_rank_from_global(int global_rank) const; + + int rank() const override { + return comm_rank; } - int size() const 
noexcept { - return m_size; + int size() const override { + return comm_size; } int pof2() const noexcept { - return m_pof2; + return comm_impl->pof2(); } ccl_comm_id_t id() const noexcept { - return m_id.value(); + return comm_id->value(); } - size_t thread_count() const noexcept { - return thread_number; + const ccl_double_tree& dtree() const { + return comm_impl->dtree(); } - size_t ranks_per_process() const noexcept { - return on_process_ranks_number; + std::unique_ptr& get_unordered_coll_manager() { + return comm_impl->unordered_coll_manager; } + std::unique_ptr& get_allreduce_2d_builder() { + return comm_impl->allreduce_2d_builder; + } + + ccl_comm* create_with_color(int color, + ccl_comm_id_storage* comm_ids, + bool share_resources) const; + + std::shared_ptr clone_with_new_id(ccl_comm_id_storage::comm_id&& id); ccl_sched_id_t get_sched_id(bool use_internal_space) { ccl_sched_id_t& next_sched_id = - (use_internal_space) ? m_next_sched_id_internal : m_next_sched_id_external; + (use_internal_space) ? next_sched_id_internal : next_sched_id_external; - ccl_sched_id_t first_sched_id = - (use_internal_space) ? static_cast(0) : ccl_comm::max_sched_count / 2; + ccl_sched_id_t first_sched_id = (use_internal_space) + ? static_cast(0) + : ccl_comm_internal::max_sched_count / 2; - ccl_sched_id_t max_sched_id = - (use_internal_space) ? ccl_comm::max_sched_count / 2 : ccl_comm::max_sched_count; + ccl_sched_id_t max_sched_id = (use_internal_space) ? 
ccl_comm_internal::max_sched_count / 2 + : ccl_comm_internal::max_sched_count; ccl_sched_id_t id = next_sched_id; @@ -135,58 +332,50 @@ class alignas(CACHELINE_SIZE) ccl_comm { next_sched_id = first_sched_id; } - LOG_DEBUG("sched_id ", id, ", comm_id ", m_id.value(), ", next sched_id ", next_sched_id); + LOG_DEBUG("sched_id ", id, ", comm_id ", this->id(), ", next sched_id ", next_sched_id); return id; } - void reset(int rank, int size) { - m_rank = rank; - m_size = size; - m_pof2 = ccl_pof2(m_size); - - m_next_sched_id_internal = ccl_comm::max_sched_count / 2; - m_next_sched_id_external = 0; - } - - /** - * Returns the number of @c rank in the global communicator - * @param rank a rank which is part of the current communicator - * @return number of @c rank in the global communicator - */ - int get_global_rank(int rank) const; - - const ccl_double_tree& dtree() const { - return m_dtree; - } - /** * Maximum available number of active communicators */ - static constexpr ccl_sched_id_t max_comm_count = std::numeric_limits::max(); + static constexpr ccl_sched_id_t max_comm_count = ccl_comm_internal::max_comm_count; /** * Maximum value of schedule id in scope of the current communicator */ - static constexpr ccl_sched_id_t max_sched_count = std::numeric_limits::max(); + static constexpr ccl_sched_id_t max_sched_count = ccl_comm_internal::max_sched_count; - std::shared_ptr atl; - std::unique_ptr unordered_coll_manager; - std::unique_ptr allreduce_2d_builder; - -private: void allocate_resources(); - int m_rank; - int m_size; - int m_pof2; - - ccl_comm_id_storage::comm_id m_id; - ccl_sched_id_t m_next_sched_id_internal; - ccl_sched_id_t m_next_sched_id_external; - ccl_rank2rank_map m_local2global_map{}; - ccl_double_tree m_dtree; +private: + // This is an internal part of the communicator, we store there only the fileds should be shared + // across ccl_comm copies/clones. Everything else must go to ccl_comm. 
+ std::shared_ptr comm_impl; + + ccl::unified_device_type device; + ccl::unified_context_type context; + + // TODO: double check if these can be moved to comm_impl as shared fields + std::shared_ptr r2r_comm; + std::shared_ptr node_comm; + std::shared_ptr even_comm; + std::shared_ptr pair_comm; + ccl::comm_split_attr comm_attr; + + // these fields are duplicate with the ones in ccl_comm_internal, but having them here + // allows to get them without going through the shared_ptr inderection. + int comm_rank; + int comm_size; + + // comm_id is not default constructible but ccl_comm is, so use unique_ptr here + std::unique_ptr comm_id; + ccl_sched_id_t next_sched_id_internal; + ccl_sched_id_t next_sched_id_external; + + ccl_comm* get_impl() { + return this; + } - size_t thread_number; - size_t on_process_ranks_number; - ccl::host_communicator* host_comm; -}; + void create_sub_comms(std::shared_ptr atl); +}; // class ccl_comm diff --git a/src/common/comm/comm_id_storage.hpp b/src/common/comm/comm_id_storage.hpp index 627bc7c9b..6254a161f 100644 --- a/src/common/comm/comm_id_storage.hpp +++ b/src/common/comm/comm_id_storage.hpp @@ -20,9 +20,6 @@ #include "common/utils/spinlock.hpp" #include -#include -#include -#include #include using ccl_comm_id_t = uint16_t; diff --git a/src/common/comm/host_communicator/host_communicator_impl.hpp b/src/common/comm/comm_impl.hpp similarity index 54% rename from src/common/comm/host_communicator/host_communicator_impl.hpp rename to src/common/comm/comm_impl.hpp index 00d8bd879..c5798f532 100644 --- a/src/common/comm/host_communicator/host_communicator_impl.hpp +++ b/src/common/comm/comm_impl.hpp @@ -15,7 +15,7 @@ */ #pragma once -#include "common/comm/host_communicator/host_communicator.hpp" +#include "common/comm/comm.hpp" #include "oneapi/ccl/native_device_api/interop_utils.hpp" #include "common/request/request.hpp" @@ -25,24 +25,22 @@ #include "coll/coll.hpp" #include "coll/coll_common_attributes.hpp" -namespace ccl { - /* allgatherv 
*/ template -ccl::event host_communicator::allgatherv_impl(const buffer_type* send_buf, - size_t send_count, - buffer_type* recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { +ccl::event ccl_comm::allgatherv_impl(const buffer_type* send_buf, + size_t send_count, + buffer_type* recv_buf, + const ccl::vector_class& recv_counts, + const ccl::stream::impl_value_t& stream, + const ccl::allgatherv_attr& attr, + const ccl::vector_class& deps) { ccl_request* req = ccl_allgatherv_impl(reinterpret_cast(send_buf), send_count, reinterpret_cast(recv_buf), recv_counts.data(), ccl::native_type_info::dtype, attr, - comm_impl.get(), + this, get_stream_ptr(stream), deps); @@ -50,13 +48,13 @@ ccl::event host_communicator::allgatherv_impl(const buffer_type* send_buf, } template -ccl::event host_communicator::allgatherv_impl(const buffer_type* send_buf, - size_t send_count, - ccl::vector_class& recv_bufs, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { +ccl::event ccl_comm::allgatherv_impl(const buffer_type* send_buf, + size_t send_count, + ccl::vector_class& recv_bufs, + const ccl::vector_class& recv_counts, + const ccl::stream::impl_value_t& stream, + const ccl::allgatherv_attr& attr, + const ccl::vector_class& deps) { ccl_coll_attr internal_attr(attr); internal_attr.is_vector_buf = 1; @@ -66,7 +64,7 @@ ccl::event host_communicator::allgatherv_impl(const buffer_type* send_buf, recv_counts.data(), ccl::native_type_info::dtype, internal_attr, - comm_impl.get(), + this, get_stream_ptr(stream), deps); @@ -74,13 +72,13 @@ ccl::event host_communicator::allgatherv_impl(const buffer_type* send_buf, } template -ccl::event host_communicator::allgatherv_impl(const buffer_type& send_buf, - size_t send_count, - buffer_type& recv_buf, - const ccl::vector_class& recv_counts, - const 
ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { +ccl::event ccl_comm::allgatherv_impl(const buffer_type& send_buf, + size_t send_count, + buffer_type& recv_buf, + const ccl::vector_class& recv_counts, + const ccl::stream::impl_value_t& stream, + const ccl::allgatherv_attr& attr, + const ccl::vector_class& deps) { ccl_coll_attr internal_attr(attr); #ifdef CCL_ENABLE_SYCL internal_attr.is_sycl_buf = 1; @@ -91,14 +89,14 @@ ccl::event host_communicator::allgatherv_impl(const buffer_type& send_buf, recv_counts.data(), ccl::native_type_info::dtype, internal_attr, - comm_impl.get(), + this, get_stream_ptr(stream), deps); return std::unique_ptr(new ccl::host_event_impl(req)); } template -ccl::event host_communicator::allgatherv_impl( +ccl::event ccl_comm::allgatherv_impl( const buffer_type& send_buf, size_t send_count, ccl::vector_class>& recv_bufs, @@ -117,7 +115,7 @@ ccl::event host_communicator::allgatherv_impl( recv_counts.data(), ccl::native_type_info::dtype, internal_attr, - comm_impl.get(), + this, get_stream_ptr(stream), deps); @@ -126,20 +124,20 @@ ccl::event host_communicator::allgatherv_impl( /* allreduce */ template -ccl::event host_communicator::allreduce_impl(const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { +ccl::event ccl_comm::allreduce_impl(const buffer_type* send_buf, + buffer_type* recv_buf, + size_t count, + ccl::reduction reduction, + const ccl::stream::impl_value_t& stream, + const ccl::allreduce_attr& attr, + const ccl::vector_class& deps) { ccl_request* req = ccl_allreduce_impl(reinterpret_cast(send_buf), reinterpret_cast(recv_buf), count, ccl::native_type_info::dtype, reduction, attr, - comm_impl.get(), + this, get_stream_ptr(stream), deps); @@ -147,13 +145,13 @@ ccl::event host_communicator::allreduce_impl(const buffer_type* 
send_buf, } template -ccl::event host_communicator::allreduce_impl(const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { +ccl::event ccl_comm::allreduce_impl(const buffer_type& send_buf, + buffer_type& recv_buf, + size_t count, + ccl::reduction reduction, + const ccl::stream::impl_value_t& stream, + const ccl::allreduce_attr& attr, + const ccl::vector_class& deps) { ccl_coll_attr internal_attr(attr); #ifdef CCL_ENABLE_SYCL internal_attr.is_sycl_buf = 1; @@ -164,7 +162,7 @@ ccl::event host_communicator::allreduce_impl(const buffer_type& send_buf, ccl::native_type_info::dtype, reduction, internal_attr, - comm_impl.get(), + this, get_stream_ptr(stream), deps); @@ -173,18 +171,18 @@ ccl::event host_communicator::allreduce_impl(const buffer_type& send_buf, /* alltoall */ template -ccl::event host_communicator::alltoall_impl(const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { +ccl::event ccl_comm::alltoall_impl(const buffer_type* send_buf, + buffer_type* recv_buf, + size_t count, + const ccl::stream::impl_value_t& stream, + const ccl::alltoall_attr& attr, + const ccl::vector_class& deps) { ccl_request* req = ccl_alltoall_impl(reinterpret_cast(send_buf), reinterpret_cast(recv_buf), count, ccl::native_type_info::dtype, attr, - comm_impl.get(), + this, get_stream_ptr(stream), deps); @@ -192,23 +190,23 @@ ccl::event host_communicator::alltoall_impl(const buffer_type* send_buf, } template -ccl::event host_communicator::alltoall_impl(const ccl::vector_class& send_buf, - const ccl::vector_class& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { +ccl::event ccl_comm::alltoall_impl(const ccl::vector_class& 
send_buf, + const ccl::vector_class& recv_buf, + size_t count, + const ccl::stream::impl_value_t& stream, + const ccl::alltoall_attr& attr, + const ccl::vector_class& deps) { throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); return {}; } template -ccl::event host_communicator::alltoall_impl(const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { +ccl::event ccl_comm::alltoall_impl(const buffer_type& send_buf, + buffer_type& recv_buf, + size_t count, + const ccl::stream::impl_value_t& stream, + const ccl::alltoall_attr& attr, + const ccl::vector_class& deps) { ccl_coll_attr internal_attr(attr); #ifdef CCL_ENABLE_SYCL internal_attr.is_sycl_buf = 1; @@ -218,7 +216,7 @@ ccl::event host_communicator::alltoall_impl(const buffer_type& send_buf, count, ccl::native_type_info::dtype, internal_attr, - comm_impl.get(), + this, get_stream_ptr(stream), deps); @@ -226,7 +224,7 @@ ccl::event host_communicator::alltoall_impl(const buffer_type& send_buf, } template -ccl::event host_communicator::alltoall_impl( +ccl::event ccl_comm::alltoall_impl( const ccl::vector_class>& send_buf, const ccl::vector_class>& recv_buf, size_t count, @@ -239,20 +237,20 @@ ccl::event host_communicator::alltoall_impl( /* alltoallv */ template -ccl::event host_communicator::alltoallv_impl(const buffer_type* send_buf, - const ccl::vector_class& send_counts, - buffer_type* recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { +ccl::event ccl_comm::alltoallv_impl(const buffer_type* send_buf, + const ccl::vector_class& send_counts, + buffer_type* recv_buf, + const ccl::vector_class& recv_counts, + const ccl::stream::impl_value_t& stream, + const ccl::alltoallv_attr& attr, + const ccl::vector_class& deps) { ccl_request* req = 
ccl_alltoallv_impl(reinterpret_cast(send_buf), send_counts.data(), reinterpret_cast(recv_buf), recv_counts.data(), ccl::native_type_info::dtype, attr, - comm_impl.get(), + this, get_stream_ptr(stream), deps); @@ -260,25 +258,25 @@ ccl::event host_communicator::alltoallv_impl(const buffer_type* send_buf, } template -ccl::event host_communicator::alltoallv_impl(const ccl::vector_class& send_buf, - const ccl::vector_class& send_counts, - const ccl::vector_class& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& dep) { +ccl::event ccl_comm::alltoallv_impl(const ccl::vector_class& send_buf, + const ccl::vector_class& send_counts, + const ccl::vector_class& recv_buf, + const ccl::vector_class& recv_counts, + const ccl::stream::impl_value_t& stream, + const ccl::alltoallv_attr& attr, + const ccl::vector_class& dep) { throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); return {}; } template -ccl::event host_communicator::alltoallv_impl(const buffer_type& send_buf, - const ccl::vector_class& send_counts, - buffer_type& recv_buf, - const ccl::vector_class& recv_counts, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { +ccl::event ccl_comm::alltoallv_impl(const buffer_type& send_buf, + const ccl::vector_class& send_counts, + buffer_type& recv_buf, + const ccl::vector_class& recv_counts, + const ccl::stream::impl_value_t& stream, + const ccl::alltoallv_attr& attr, + const ccl::vector_class& deps) { ccl_coll_attr internal_attr(attr); #ifdef CCL_ENABLE_SYCL internal_attr.is_sycl_buf = 1; @@ -289,7 +287,7 @@ ccl::event host_communicator::alltoallv_impl(const buffer_type& send_buf, recv_counts.data(), ccl::native_type_info::dtype, internal_attr, - comm_impl.get(), + this, get_stream_ptr(stream), deps); @@ -297,7 +295,7 @@ ccl::event host_communicator::alltoallv_impl(const buffer_type& 
send_buf, } template -ccl::event host_communicator::alltoallv_impl( +ccl::event ccl_comm::alltoallv_impl( const ccl::vector_class>& send_buf, const ccl::vector_class& send_counts, const ccl::vector_class>& recv_buf, @@ -311,18 +309,18 @@ ccl::event host_communicator::alltoallv_impl( /* bcast */ template -ccl::event host_communicator::broadcast_impl(buffer_type* buf, - size_t count, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { +ccl::event ccl_comm::broadcast_impl(buffer_type* buf, + size_t count, + int root, + const ccl::stream::impl_value_t& stream, + const ccl::broadcast_attr& attr, + const ccl::vector_class& deps) { ccl_request* req = ccl_broadcast_impl(reinterpret_cast(buf), count, ccl::native_type_info::dtype, root, attr, - comm_impl.get(), + this, get_stream_ptr(stream), deps); @@ -330,12 +328,12 @@ ccl::event host_communicator::broadcast_impl(buffer_type* buf, } template -ccl::event host_communicator::broadcast_impl(buffer_type& buf, - size_t count, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { +ccl::event ccl_comm::broadcast_impl(buffer_type& buf, + size_t count, + int root, + const ccl::stream::impl_value_t& stream, + const ccl::broadcast_attr& attr, + const ccl::vector_class& deps) { ccl_coll_attr internal_attr(attr); #ifdef CCL_ENABLE_SYCL internal_attr.is_sycl_buf = 1; @@ -345,7 +343,7 @@ ccl::event host_communicator::broadcast_impl(buffer_type& buf, ccl::native_type_info::dtype, root, internal_attr, - comm_impl.get(), + this, get_stream_ptr(stream), deps); @@ -354,14 +352,14 @@ ccl::event host_communicator::broadcast_impl(buffer_type& buf, /* reduce */ template -ccl::event host_communicator::reduce_impl(const buffer_type* send_buf, - buffer_type* recv_buf, - size_t count, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const 
ccl::vector_class& deps) { +ccl::event ccl_comm::reduce_impl(const buffer_type* send_buf, + buffer_type* recv_buf, + size_t count, + ccl::reduction reduction, + int root, + const ccl::stream::impl_value_t& stream, + const ccl::reduce_attr& attr, + const ccl::vector_class& deps) { ccl_request* req = ccl_reduce_impl(reinterpret_cast(send_buf), reinterpret_cast(recv_buf), count, @@ -369,7 +367,7 @@ ccl::event host_communicator::reduce_impl(const buffer_type* send_buf, reduction, root, attr, - comm_impl.get(), + this, get_stream_ptr(stream), deps); @@ -377,14 +375,14 @@ ccl::event host_communicator::reduce_impl(const buffer_type* send_buf, } template -ccl::event host_communicator::reduce_impl(const buffer_type& send_buf, - buffer_type& recv_buf, - size_t count, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { +ccl::event ccl_comm::reduce_impl(const buffer_type& send_buf, + buffer_type& recv_buf, + size_t count, + ccl::reduction reduction, + int root, + const ccl::stream::impl_value_t& stream, + const ccl::reduce_attr& attr, + const ccl::vector_class& deps) { ccl_coll_attr internal_attr(attr); #ifdef CCL_ENABLE_SYCL internal_attr.is_sycl_buf = 1; @@ -396,7 +394,7 @@ ccl::event host_communicator::reduce_impl(const buffer_type& send_buf, reduction, root, internal_attr, - comm_impl.get(), + this, get_stream_ptr(stream), deps); @@ -405,20 +403,20 @@ ccl::event host_communicator::reduce_impl(const buffer_type& send_buf, /* reduce_scatter */ template -ccl::event host_communicator::reduce_scatter_impl(const buffer_type* send_buf, - buffer_type* recv_buf, - size_t recv_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { +ccl::event ccl_comm::reduce_scatter_impl(const buffer_type* send_buf, + buffer_type* recv_buf, + size_t recv_count, + ccl::reduction reduction, + const 
ccl::stream::impl_value_t& stream, + const ccl::reduce_scatter_attr& attr, + const ccl::vector_class& deps) { ccl_request* req = ccl_reduce_scatter_impl(reinterpret_cast(send_buf), reinterpret_cast(recv_buf), recv_count, ccl::native_type_info::dtype, reduction, attr, - comm_impl.get(), + this, get_stream_ptr(stream), deps); @@ -426,13 +424,13 @@ ccl::event host_communicator::reduce_scatter_impl(const buffer_type* send_buf, } template -ccl::event host_communicator::reduce_scatter_impl(const buffer_type& send_buf, - buffer_type& recv_buf, - size_t recv_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { +ccl::event ccl_comm::reduce_scatter_impl(const buffer_type& send_buf, + buffer_type& recv_buf, + size_t recv_count, + ccl::reduction reduction, + const ccl::stream::impl_value_t& stream, + const ccl::reduce_scatter_attr& attr, + const ccl::vector_class& deps) { ccl_coll_attr internal_attr(attr); #ifdef CCL_ENABLE_SYCL internal_attr.is_sycl_buf = 1; @@ -443,7 +441,7 @@ ccl::event host_communicator::reduce_scatter_impl(const buffer_type& send_buf, ccl::native_type_info::dtype, reduction, internal_attr, - comm_impl.get(), + this, get_stream_ptr(stream), deps); @@ -452,18 +450,18 @@ ccl::event host_communicator::reduce_scatter_impl(const buffer_type& send_buf, /* sparse_allreduce */ template -ccl::event host_communicator::sparse_allreduce_impl(const index_buffer_type* send_ind_buf, - size_t send_ind_count, - const value_buffer_type* send_val_buf, - size_t send_val_count, - index_buffer_type* recv_ind_buf, - size_t recv_ind_count, - value_buffer_type* recv_val_buf, - size_t recv_val_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { +ccl::event ccl_comm::sparse_allreduce_impl(const index_buffer_type* send_ind_buf, + size_t send_ind_count, + const value_buffer_type* 
send_val_buf, + size_t send_val_count, + index_buffer_type* recv_ind_buf, + size_t recv_ind_count, + value_buffer_type* recv_val_buf, + size_t recv_val_count, + ccl::reduction reduction, + const ccl::stream::impl_value_t& stream, + const ccl::sparse_allreduce_attr& attr, + const ccl::vector_class& deps) { ccl_request* req = ccl_sparse_allreduce_impl((const void*)send_ind_buf, send_ind_count, (const void*)send_val_buf, @@ -476,7 +474,7 @@ ccl::event host_communicator::sparse_allreduce_impl(const index_buffer_type* sen ccl::native_type_info::dtype, reduction, attr, - comm_impl.get(), + this, get_stream_ptr(stream), deps); @@ -484,20 +482,18 @@ ccl::event host_communicator::sparse_allreduce_impl(const index_buffer_type* sen } template -ccl::event host_communicator::sparse_allreduce_impl(const index_buffer_container_type& send_ind_buf, - size_t send_ind_count, - const value_buffer_container_type& send_val_buf, - size_t send_val_count, - index_buffer_container_type& recv_ind_buf, - size_t recv_ind_count, - value_buffer_container_type& recv_val_buf, - size_t recv_val_count, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { +ccl::event ccl_comm::sparse_allreduce_impl(const index_buffer_container_type& send_ind_buf, + size_t send_ind_count, + const value_buffer_container_type& send_val_buf, + size_t send_val_count, + index_buffer_container_type& recv_ind_buf, + size_t recv_ind_count, + value_buffer_container_type& recv_val_buf, + size_t recv_val_count, + ccl::reduction reduction, + const ccl::stream::impl_value_t& stream, + const ccl::sparse_allreduce_attr& attr, + const ccl::vector_class& deps) { throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); return {}; } - -} // namespace ccl diff --git a/src/common/comm/comm_interface.hpp b/src/common/comm/comm_interface.hpp index a72256413..8b230a00d 100644 --- a/src/common/comm/comm_interface.hpp +++ 
b/src/common/comm/comm_interface.hpp @@ -29,7 +29,6 @@ #include "oneapi/ccl/stream.hpp" #include "common/comm/compiler_comm_interface_dispatcher.hpp" -#include "common/comm/l0/comm_context_id.hpp" #include "internal_types.hpp" namespace native { @@ -48,8 +47,6 @@ class reduce_attr; class reduce_scatter_attr; class sparse_allreduce_attr; } // namespace v1 - -struct gpu_comm_attr; } // namespace ccl #include "types_generator_defines.hpp" @@ -149,8 +146,6 @@ struct communicator_interface : public communicator_interface_dispatcher { virtual bool is_ready() const = 0; - virtual const group_unique_key& get_comm_group_id() const = 0; - virtual ccl::communicator_interface_ptr split(const ccl::comm_split_attr& attr) = 0; // collectives operation declarations diff --git a/src/common/comm/communicator_traits.hpp b/src/common/comm/communicator_traits.hpp index cf0ea5d3e..b91629954 100644 --- a/src/common/comm/communicator_traits.hpp +++ b/src/common/comm/communicator_traits.hpp @@ -47,15 +47,4 @@ struct host_communicator_traits : base_communicator_traits { - static constexpr const char* name() { - return "cpu communicator"; - } -}; - -struct gpu_communicator_traits : base_communicator_traits { - static constexpr const char* name() { - return "gpu communicator"; - } -}; } // namespace ccl diff --git a/src/common/comm/compiler_comm_interface_dispatcher.cpp b/src/common/comm/compiler_comm_interface_dispatcher.cpp index 3c1373925..52d486e25 100644 --- a/src/common/comm/compiler_comm_interface_dispatcher.cpp +++ b/src/common/comm/compiler_comm_interface_dispatcher.cpp @@ -30,29 +30,29 @@ #include "common/global/global.hpp" -#ifdef MULTI_GPU_SUPPORT +#ifdef CCL_ENABLE_ZE #include "supported_topologies.hpp" #endif -#include "common/comm/host_communicator/host_communicator_impl.hpp" +#include "common/comm/comm_impl.hpp" namespace ccl { communicator_interface_ptr communicator_interface_dispatcher::create_communicator_impl() { - return communicator_interface_ptr(new 
host_communicator()); + return communicator_interface_ptr(new ccl_comm()); } communicator_interface_ptr communicator_interface_dispatcher::create_communicator_impl( const size_t size, shared_ptr_class kvs) { - return communicator_interface_ptr(new host_communicator(size, kvs)); + return communicator_interface_ptr(new ccl_comm(size, kvs)); } communicator_interface_ptr communicator_interface_dispatcher::create_communicator_impl( const size_t size, const int rank, shared_ptr_class kvs) { - return communicator_interface_ptr(new host_communicator(size, rank, kvs)); + return communicator_interface_ptr(new ccl_comm(size, rank, kvs)); } template atl, + std::shared_ptr atl, ccl::group_split_type preferred_topology_group /* = ccl::group_split_type::undetermined */) { static_assert(std::is_same::value, "Unsupported 'DeviceType'"); @@ -92,7 +92,7 @@ communicator_interface_ptr communicator_interface_dispatcher::create_communicato size_t thread_idx, size_t process_idx, const ccl::comm_split_attr& attr, - std::shared_ptr atl, + std::shared_ptr atl, ccl::group_split_type preferred_topology_group /* = ccl::group_split_type::undetermined */) { #ifdef CCL_ENABLE_SYCL return communicator_interface_dispatcher::create_communicator_from_unified_device( @@ -122,7 +122,7 @@ communicator_interface_dispatcher::create_communicator_from_unified_device( size_t thread_idx, size_t process_idx, const ccl::comm_split_attr& attr, - std::shared_ptr atl, + std::shared_ptr atl, ccl::group_split_type preferred_topology_group /* = ccl::group_split_type::undetermined */) { if (preferred_topology_group == ccl::group_split_type::undetermined) { preferred_topology_group = ccl::group_split_type::cluster; @@ -140,10 +140,10 @@ communicator_interface_dispatcher::create_communicator_from_unified_device( } switch (preferred_topology_group) { -#if defined(MULTI_GPU_SUPPORT) || defined(CCL_ENABLE_SYCL) +#if defined(CCL_ENABLE_ZE) || defined(CCL_ENABLE_SYCL) case ccl::group_split_type::single: { return 
communicator_interface_ptr( - new host_communicator(std::move(device_id), std::move(context), atl)); + new ccl_comm(std::move(device_id), std::move(context), atl)); } #endif default: @@ -164,7 +164,7 @@ communicator_interface_dispatcher::create_communicator_from_unified_device( size_t thread_idx, \ size_t process_idx, \ const ccl::comm_split_attr& attr, \ - std::shared_ptr atl, \ + std::shared_ptr atl, \ ccl::group_split_type \ preferred_topology_group /* = ccl::group_split_type::undetermined */); @@ -177,7 +177,7 @@ communicator_interface_dispatcher::create_communicator_from_unified_device( size_t thread_idx, \ size_t process_idx, \ const ccl::comm_split_attr& attr, \ - std::shared_ptr atl, \ + std::shared_ptr atl, \ ccl::group_split_type \ preferred_topology_group /* = ccl::group_split_type::undetermined */); diff --git a/src/common/comm/compiler_comm_interface_dispatcher.hpp b/src/common/comm/compiler_comm_interface_dispatcher.hpp index ad643f0a9..cb8ca8bb2 100644 --- a/src/common/comm/compiler_comm_interface_dispatcher.hpp +++ b/src/common/comm/compiler_comm_interface_dispatcher.hpp @@ -20,7 +20,7 @@ #include "oneapi/ccl/types.hpp" #include "supported_topologies.hpp" #include "communicator_traits.hpp" -#include "atl/atl_wrapper.h" +#include "atl/atl_base_comm.hpp" namespace native { struct ccl_device; @@ -30,9 +30,6 @@ namespace v1 { class comm_split_attr; } -#ifdef MULTI_GPU_SUPPORT -struct gpu_comm_attr; -#endif struct communicator_interface; using communicator_interface_ptr = std::shared_ptr; @@ -43,10 +40,6 @@ struct communicator_interface_dispatcher { virtual ~communicator_interface_dispatcher() = default; -#ifdef MULTI_GPU_SUPPORT - virtual void visit(ccl::gpu_comm_attr& comm_attr) = 0; -#endif //MULTI_GPU_SUPPORT - virtual ccl::device_index_type get_device_path() const = 0; virtual device_t get_device() const = 0; virtual context_t get_context() const = 0; @@ -66,7 +59,7 @@ struct communicator_interface_dispatcher { size_t thread_idx, size_t process_idx, 
const comm_split_attr& attr, - std::shared_ptr atl, + std::shared_ptr atl, ccl::group_split_type preferred_topology_group = ccl::group_split_type::undetermined); // create communicator for device & cpu types (from device index) @@ -81,7 +74,7 @@ struct communicator_interface_dispatcher { size_t thread_idx, size_t process_idx, const comm_split_attr& attr, - std::shared_ptr atl, + std::shared_ptr atl, ccl::group_split_type preferred_topology_group = ccl::group_split_type::undetermined); // create communicator for host @@ -103,7 +96,7 @@ struct communicator_interface_dispatcher { size_t thread_idx, size_t process_idx, const comm_split_attr& attr, - std::shared_ptr atl, + std::shared_ptr atl, ccl::group_split_type preferred_topology_group = ccl::group_split_type::undetermined); }; } // namespace ccl diff --git a/src/common/comm/host_communicator/host_communicator.cpp b/src/common/comm/host_communicator/host_communicator.cpp deleted file mode 100644 index b0883492d..000000000 --- a/src/common/comm/host_communicator/host_communicator.cpp +++ /dev/null @@ -1,499 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-*/ -#include "common/global/global.hpp" -#include "common/comm/host_communicator/host_communicator_impl.hpp" -#include "oneapi/ccl/comm_split_attr_ids.hpp" -#include "oneapi/ccl/comm_split_attr_ids_traits.hpp" -#include "oneapi/ccl/comm_split_attr.hpp" - -#include "common/request/request.hpp" -#include "common/event/impls/host_event.hpp" -#include "coll/coll.hpp" -#include "coll/coll_common_attributes.hpp" -#include "coll/ccl_allgather_op_attr.hpp" - -#include "util/pm/pmi_resizable_rt/pmi_resizable/kvs/ikvs_wrapper.h" -#include "atl/atl_wrapper.h" - -#include "common/comm/comm.hpp" - -#ifdef MULTI_GPU_SUPPORT -#include "common/comm/l0/gpu_comm_attr.hpp" -#endif - -namespace ccl { - -using ccl::preview::create_comm_split_attr; - -host_communicator::host_communicator() - : device(ccl::device_index_type(ccl::unused_index_value, - ccl::unused_index_value, - ccl::unused_index_value)), - comm_attr(create_comm_split_attr()) {} - -host_communicator::host_communicator(int size, shared_ptr_class kvs) - : device(ccl::device_index_type(ccl::unused_index_value, - ccl::unused_index_value, - ccl::unused_index_value)), - comm_attr(create_comm_split_attr()), - comm_rank(0), - comm_size(size) { - if (size <= 0) { - throw ccl::exception("Incorrect size value when creating a host communicator"); - } -} - -host_communicator::host_communicator(int size, int rank, shared_ptr_class kvs) - : device(ccl::device_index_type(ccl::unused_index_value, - ccl::unused_index_value, - ccl::unused_index_value)), - comm_attr(create_comm_split_attr()), - comm_rank(rank), - comm_size(size) { - if (rank > size || size <= 0) { - throw ccl::exception("Incorrect rank or size value when creating a host communicator"); - } - - LOG_DEBUG("ctor"); - - ccl::global_data& data = ccl::global_data::get(); - std::shared_ptr atl_tmp = - std::shared_ptr(new atl_wrapper(size, { rank }, kvs)); - comm_impl = std::shared_ptr( - new ccl_comm(rank, size, data.comm_ids->acquire(), atl_tmp, false, this)); - 
create_sub_comms(atl_tmp); -} - -host_communicator::host_communicator(ccl::unified_device_type&& d, - ccl::unified_context_type&& c, - std::shared_ptr atl) - : host_communicator(atl) {} - -host_communicator::host_communicator(std::shared_ptr atl) - : device(ccl::device_index_type(ccl::unused_index_value, - ccl::unused_index_value, - ccl::unused_index_value)), - comm_attr(create_comm_split_attr()), - comm_rank(atl->get_rank()), - comm_size(atl->get_size()) { - int rank = atl->get_rank(); - int size = atl->get_size(); - - if (rank > size || size <= 0) { - throw ccl::exception("incorrect rank or size when creating \ - a host communicator: rank: " + - std::to_string(rank) + ", size: " + std::to_string(size)); - } - - LOG_DEBUG("ctor"); - - ccl::global_data& data = ccl::global_data::get(); - comm_impl = std::shared_ptr( - new ccl_comm(rank, size, data.comm_ids->acquire(), atl, false, this)); - create_sub_comms(atl); -} - -host_communicator::host_communicator(std::shared_ptr impl, bool is_sub_communicator) - : comm_impl(impl), - device(ccl::device_index_type(ccl::unused_index_value, - ccl::unused_index_value, - ccl::unused_index_value)), - comm_attr(create_comm_split_attr()), - comm_rank(impl->rank()), - comm_size(impl->size()) { - if (!is_sub_communicator) { - create_sub_comms(comm_impl.get()->atl); - } -} - -int host_communicator::rank() const { - return comm_rank; -} - -int host_communicator::size() const { - return comm_size; -} - -#ifdef MULTI_GPU_SUPPORT -void host_communicator::visit(ccl::gpu_comm_attr& comm_attr) { - (void)(comm_attr); -} -#endif - -ccl::device_index_type host_communicator::get_device_path() const { - return ccl::device_index_type{ ccl::unused_index_value, - ccl::unused_index_value, - ccl::unused_index_value }; -} - -ccl::communicator_interface::device_t host_communicator::get_device() const { - throw ccl::exception(std::string(__FUNCTION__) + " is not applicable for " + traits::name()); - static ccl::communicator_interface::device_t empty; - 
return empty; -} - -ccl::communicator_interface::context_t host_communicator::get_context() const { - throw ccl::exception(std::string(__FUNCTION__) + " is not applicable for " + traits::name()); - static ccl::communicator_interface::context_t empty; - return empty; -} - -void host_communicator::exchange_colors(std::vector& colors) { - size_t send_count = 1; - vector_class recv_counts(colors.size(), send_count); - auto attr = - create_operation_attr(attr_val(false)); - - this->allgatherv_impl(colors.data(), send_count, colors.data(), recv_counts, {}, attr, {}) - .wait(); -} - -void host_communicator::create_sub_comms(std::shared_ptr atl) { - bool is_sub_comm = true; - if (ccl::global_data::env().atl_transport == ccl_atl_mpi) { - r2r_comm = - std::shared_ptr(new host_communicator(comm_impl, is_sub_comm)); - node_comm = - std::shared_ptr(new host_communicator(comm_impl, is_sub_comm)); - pair_comm = - std::shared_ptr(new host_communicator(comm_impl, is_sub_comm)); - even_comm = - std::shared_ptr(new host_communicator(comm_impl, is_sub_comm)); - } - else { - ccl::global_data& data = ccl::global_data::get(); - r2r_comm = std::shared_ptr( - new host_communicator(std::shared_ptr(this->create_with_color( - atl->get_r2r_color(), data.comm_ids.get(), comm_impl.get())), - is_sub_comm)); - node_comm = std::shared_ptr( - new host_communicator(std::shared_ptr(this->create_with_color( - atl->get_host_color(), data.comm_ids.get(), comm_impl.get())), - is_sub_comm)); - even_comm = std::shared_ptr(new host_communicator( - std::shared_ptr(this->create_with_color( - atl->get_host_color() + atl->get_rank() % 2, data.comm_ids.get(), comm_impl.get())), - is_sub_comm)); - pair_comm = std::shared_ptr(new host_communicator( - std::shared_ptr(this->create_with_color( - atl->get_host_color() + atl->get_rank() / 2, data.comm_ids.get(), comm_impl.get())), - is_sub_comm)); - } -} - -ccl_comm* host_communicator::create_with_color(int color, - ccl_comm_id_storage* comm_ids, - const ccl_comm* 
parent_comm) { - if (ccl::global_data::env().atl_transport == ccl_atl_mpi) { - throw ccl::exception( - "MPI transport doesn't support creation of communicator with color yet"); - } - - std::vector colors(this->size()); - colors[this->rank()] = color; - this->exchange_colors(colors); - - // TODO we can replace this func with own - return ccl_comm::create_with_colors(colors, comm_ids, parent_comm, true); -} - -ccl::communicator_interface_ptr host_communicator::split(const comm_split_attr& attr) { - if (!attr.is_valid()) { - throw ccl::exception(std::string(__FUNCTION__) + - " - 'Color' split attribute for host communicator is not set"); - } - - ccl::global_data& data = ccl::global_data::get(); - auto new_comm = this->create_with_color( - attr.get(), data.comm_ids.get(), comm_impl.get()); - - comm_attr = attr; - - return std::shared_ptr( - new host_communicator(std::shared_ptr(new_comm))); -} - -ccl::event host_communicator::barrier(const ccl::stream::impl_value_t& stream, - const ccl::barrier_attr& attr, - const ccl::vector_class& deps) { - return get_impl()->barrier_impl(stream, attr, deps); -} - -ccl::event host_communicator::barrier_impl(const ccl::stream::impl_value_t& stream, - const ccl::barrier_attr& attr, - const ccl::vector_class& deps) { - ccl_barrier_impl(comm_impl.get(), stream.get(), deps); - return std::unique_ptr(new ccl::host_event_impl(nullptr)); -} - -/* allgatherv */ -ccl::event host_communicator::allgatherv_impl(const void* send_buf, - size_t send_count, - void* recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - ccl_request* req = ccl_allgatherv_impl(send_buf, - send_count, - recv_buf, - recv_counts.data(), - dtype, - attr, - comm_impl.get(), - get_stream_ptr(stream), - deps); - - return std::unique_ptr(new ccl::host_event_impl(req)); -} - -ccl::event host_communicator::allgatherv_impl(const void* send_buf, - 
size_t send_count, - const ccl::vector_class& recv_bufs, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::allgatherv_attr& attr, - const ccl::vector_class& deps) { - ccl_coll_attr internal_attr(attr); - internal_attr.is_vector_buf = 1; - - ccl_request* req = ccl_allgatherv_impl(reinterpret_cast(send_buf), - send_count, - (void*)(recv_bufs.data()), - recv_counts.data(), - dtype, - internal_attr, - comm_impl.get(), - get_stream_ptr(stream), - deps); - - return std::unique_ptr(new ccl::host_event_impl(req)); -} - -/* allreduce */ -ccl::event host_communicator::allreduce_impl(const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::allreduce_attr& attr, - const ccl::vector_class& deps) { - ccl_request* req = ccl_allreduce_impl(send_buf, - recv_buf, - count, - dtype, - reduction, - attr, - comm_impl.get(), - get_stream_ptr(stream), - deps); - - return std::unique_ptr(new ccl::host_event_impl(req)); -} - -/* alltoall */ -ccl::event host_communicator::alltoall_impl(const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - ccl_request* req = ccl_alltoall_impl( - send_buf, recv_buf, count, dtype, attr, comm_impl.get(), get_stream_ptr(stream), deps); - - return std::unique_ptr(new ccl::host_event_impl(req)); -} - -ccl::event host_communicator::alltoall_impl(const ccl::vector_class& send_buf, - const ccl::vector_class& recv_buf, - size_t count, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoall_attr& attr, - const ccl::vector_class& deps) { - // TODO not implemented - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* alltoallv */ -ccl::event host_communicator::alltoallv_impl(const 
void* send_buf, - const ccl::vector_class& send_counts, - void* recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& deps) { - ccl_request* req = ccl_alltoallv_impl(send_buf, - send_counts.data(), - recv_buf, - recv_counts.data(), - dtype, - attr, - comm_impl.get(), - get_stream_ptr(stream), - deps); - - return std::unique_ptr(new ccl::host_event_impl(req)); -} - -ccl::event host_communicator::alltoallv_impl(const ccl::vector_class& send_buf, - const ccl::vector_class& send_counts, - ccl::vector_class recv_buf, - const ccl::vector_class& recv_counts, - ccl::datatype dtype, - const ccl::stream::impl_value_t& stream, - const ccl::alltoallv_attr& attr, - const ccl::vector_class& dep) { - // TODO not implemented - throw ccl::exception(std::string(__PRETTY_FUNCTION__) + " - is not implemented"); - return {}; -} - -/* bcast */ -ccl::event host_communicator::broadcast_impl(void* buf, - size_t count, - ccl::datatype dtype, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::broadcast_attr& attr, - const ccl::vector_class& deps) { - ccl_request* req = ccl_broadcast_impl( - buf, count, dtype, root, attr, comm_impl.get(), get_stream_ptr(stream), deps); - - return std::unique_ptr(new ccl::host_event_impl(req)); -} - -/* reduce */ -ccl::event host_communicator::reduce_impl(const void* send_buf, - void* recv_buf, - size_t count, - ccl::datatype dtype, - ccl::reduction reduction, - int root, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_attr& attr, - const ccl::vector_class& deps) { - ccl_request* req = ccl_reduce_impl(send_buf, - recv_buf, - count, - dtype, - reduction, - root, - attr, - comm_impl.get(), - get_stream_ptr(stream), - deps); - - return std::unique_ptr(new ccl::host_event_impl(req)); -} - -/* reduce_scatter */ -ccl::event host_communicator::reduce_scatter_impl(const void* send_buf, - void* recv_buf, - size_t 
recv_count, - ccl::datatype dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::reduce_scatter_attr& attr, - const ccl::vector_class& deps) { - ccl_request* req = ccl_reduce_scatter_impl(send_buf, - recv_buf, - recv_count, - dtype, - reduction, - attr, - comm_impl.get(), - get_stream_ptr(stream), - deps); - - return std::unique_ptr(new ccl::host_event_impl(req)); -} - -/* sparse_allreduce */ -ccl::event host_communicator::sparse_allreduce_impl(const void* send_ind_buf, - size_t send_ind_count, - const void* send_val_buf, - size_t send_val_count, - void* recv_ind_buf, - size_t recv_ind_count, - void* recv_val_buf, - size_t recv_val_count, - ccl::datatype index_dtype, - ccl::datatype value_dtype, - ccl::reduction reduction, - const ccl::stream::impl_value_t& stream, - const ccl::sparse_allreduce_attr& attr, - const ccl::vector_class& deps) { - ccl_request* req = ccl_sparse_allreduce_impl(send_ind_buf, - send_ind_count, - send_val_buf, - send_val_count, - recv_ind_buf, - recv_ind_count, - recv_val_buf, - recv_val_count, - index_dtype, - value_dtype, - reduction, - attr, - comm_impl.get(), - get_stream_ptr(stream), - deps); - - return std::unique_ptr(new ccl::host_event_impl(req)); -} - -std::shared_ptr host_communicator::get_atl() { - return comm_impl->atl; -} - -std::shared_ptr host_communicator::get_r2r_comm() { - return r2r_comm; -} - -std::shared_ptr host_communicator::get_node_comm() { - return node_comm; -} - -std::shared_ptr host_communicator::get_pair_comm() { - return pair_comm; -} - -std::shared_ptr host_communicator::get_even_comm() { - return even_comm; -} - -std::shared_ptr host_communicator::get_ccl_comm() { - return comm_impl; -} - -std::string host_communicator::to_string() const { - return std::string("host communicator, rank (") + std::to_string(rank()) + "/" + - std::to_string(size()); -} - -COMM_INTERFACE_COLL_INSTANTIATION(host_communicator); -#ifdef CCL_ENABLE_SYCL 
-SYCL_COMM_INTERFACE_COLL_INSTANTIATION(host_communicator); -#endif // CCL_ENABLE_SYCL - -} // namespace ccl diff --git a/src/common/comm/host_communicator/host_communicator.hpp b/src/common/comm/host_communicator/host_communicator.hpp deleted file mode 100644 index 53bb642ce..000000000 --- a/src/common/comm/host_communicator/host_communicator.hpp +++ /dev/null @@ -1,179 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once - -#include "atl/atl_wrapper.h" -#include "common/comm/comm.hpp" -#include "common/stream/stream.hpp" -#include "oneapi/ccl/types.hpp" -#include "oneapi/ccl/types_policy.hpp" -#include "oneapi/ccl/comm_split_attr_ids.hpp" -#include "oneapi/ccl/comm_split_attr_ids_traits.hpp" -#include "oneapi/ccl/comm_split_attr.hpp" -#include "oneapi/ccl/types.hpp" -#include "oneapi/ccl/type_traits.hpp" -#include "oneapi/ccl/types_policy.hpp" -#include "oneapi/ccl/event.hpp" -#include "oneapi/ccl/coll_attr_ids.hpp" -#include "oneapi/ccl/coll_attr_ids_traits.hpp" -#include "oneapi/ccl/coll_attr.hpp" - -#include "common/comm/communicator_traits.hpp" -#include "common/comm/comm_interface.hpp" -#include "types_generator_defines.hpp" - -class ikvs_wrapper; -namespace ccl { - -inline ccl_stream* get_stream_ptr(const ccl::stream::impl_value_t& stream) { - if (stream.get() && stream->is_sycl_device_stream()) - return stream.get(); - else - return nullptr; -} - -class host_communicator : public ccl::communicator_interface 
{ -public: - using traits = ccl::host_communicator_traits; - - int rank() const override; - int size() const override; - - // traits - bool is_host() const noexcept override { - return traits::is_host(); - } - - bool is_cpu() const noexcept override { - return traits::is_cpu(); - } - - bool is_gpu() const noexcept override { - return traits::is_gpu(); - } - - bool is_accelerator() const noexcept override { - return traits::is_accelerator(); - } - - bool is_ready() const override { - return true; - } - - const ccl::group_unique_key& get_comm_group_id() const override { - return owner_id; - } - - void set_comm_group_id(ccl::group_unique_key id) { - owner_id = id; - } - -#ifdef MULTI_GPU_SUPPORT - void visit(ccl::gpu_comm_attr& comm_attr) override; -#endif - - ccl::device_index_type get_device_path() const override; - ccl::communicator_interface::device_t get_device() const override; - ccl::communicator_interface::context_t get_context() const override; - - const ccl::comm_split_attr& get_comm_split_attr() const override { - return comm_attr; - } - - ccl::group_split_type get_topology_type() const override { - throw ccl::exception(std::string(__FUNCTION__) + " is not applicable for " + - traits::name()); - return ccl::group_split_type::undetermined; - } - - ccl::device_topology_type get_topology_class() const override { - throw ccl::exception(std::string(__FUNCTION__) + " is not applicable for " + - traits::name()); - return ccl::device_topology_type::undetermined; - } - - ccl::communicator_interface_ptr split(const comm_split_attr& attr) override; - - // collectives operation declarations - ccl::event barrier(const stream::impl_value_t& op_stream, - const barrier_attr& attr, - const vector_class& deps = {}) override; - ccl::event barrier_impl(const stream::impl_value_t& op_stream, - const barrier_attr& attr, - const vector_class& deps = {}); - - COMM_INTERFACE_COLL_METHODS(DEFINITION); -#ifdef CCL_ENABLE_SYCL - SYCL_COMM_INTERFACE_COLL_METHODS(DEFINITION); -#endif // 
CCL_ENABLE_SYCL - - COMM_IMPL_DECLARATION; - COMM_IMPL_CLASS_DECLARATION - COMM_IMPL_SPARSE_DECLARATION; - COMM_IMPL_SPARSE_CLASS_DECLARATION - - host_communicator(); - host_communicator(int size, shared_ptr_class kvs); - host_communicator(int size, int rank, shared_ptr_class kvs); - host_communicator(ccl::unified_device_type&& device, - ccl::unified_context_type&& context, - std::shared_ptr atl); - host_communicator(std::shared_ptr atl); - host_communicator(std::shared_ptr impl, bool is_sub_communicator = false); - host_communicator(host_communicator& src) = delete; - host_communicator(host_communicator&& src) = default; - host_communicator& operator=(host_communicator& src) = delete; - host_communicator& operator=(host_communicator&& src) = default; - ~host_communicator() = default; - std::shared_ptr get_atl(); - std::shared_ptr get_r2r_comm(); - std::shared_ptr get_node_comm(); - std::shared_ptr get_even_comm(); - std::shared_ptr get_pair_comm(); - std::shared_ptr get_ccl_comm(); - - // troubleshooting - std::string to_string() const; - -private: - friend struct group_context; - - std::shared_ptr comm_impl; - - ccl::unified_device_type device; - //ccl::unified_context_type context; - - std::shared_ptr r2r_comm; - std::shared_ptr node_comm; - std::shared_ptr even_comm; - std::shared_ptr pair_comm; - ccl::comm_split_attr comm_attr; - int comm_rank; - int comm_size; - ccl::group_unique_key owner_id; - - host_communicator* get_impl() { - return this; - } - - void exchange_colors(std::vector& colors); - void create_sub_comms(std::shared_ptr atl); - ccl_comm* create_with_color(int color, - ccl_comm_id_storage* comm_ids, - const ccl_comm* parent_comm); -}; // class host_communicator - -} // namespace ccl diff --git a/src/common/comm/l0/base_connector.hpp b/src/common/comm/l0/base_connector.hpp deleted file mode 100644 index e3f6ebda9..000000000 --- a/src/common/comm/l0/base_connector.hpp +++ /dev/null @@ -1,24 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - 
Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#pragma once - -template -struct base_connector_interface { - using visitor = visitor_to_connect; - - virtual ~base_connector_interface() noexcept = default; - virtual bool operator()(visitor_to_connect& to_connect) = 0; -}; diff --git a/src/common/comm/l0/comm_context.cpp b/src/common/comm/l0/comm_context.cpp deleted file mode 100644 index 196079c12..000000000 --- a/src/common/comm/l0/comm_context.cpp +++ /dev/null @@ -1,58 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-*/ -#include "oneapi/ccl/aliases.hpp" -#include "common/comm/host_communicator/host_communicator.hpp" -#include "common/comm/l0/comm_context_impl.hpp" -#include "common/utils/spinlock.hpp" -#include "common/comm/atl_tag.hpp" - -namespace ccl { -comm_group::comm_group(shared_communicator_t parent_comm, - size_t threads_per_process, - size_t ranks_per_process, - group_unique_key id) - : pimpl(new gpu_comm_attr(parent_comm, threads_per_process, ranks_per_process, id)){}; - -bool comm_group::sync_group_size(size_t device_group_size) { - return pimpl->sync_group_size(device_group_size); -} - -comm_group::~comm_group() {} - -const group_unique_key& comm_group::get_unique_id() const { - return pimpl->get_unique_id(); -} -/* -std::string comm_group::to_string() const -{ - pimpl->to_string(); -}*/ -} // namespace ccl -// container-based method force-instantiation will trigger ALL other methods instantiations -COMM_CREATOR_INDEXED_INSTANTIATION_CONTAINER(ccl::vector_class, - typename ccl::unified_context_type::ccl_native_t); -COMM_CREATOR_INDEXED_INSTANTIATION_CONTAINER(ccl::list_class, - typename ccl::unified_context_type::ccl_native_t); -COMM_CREATOR_INDEXED_INSTANTIATION_CONTAINER(ccl::device_indices_type, - typename ccl::unified_context_type::ccl_native_t); -COMM_CREATOR_INDEXED_INSTANTIATION_TYPE(ccl::device_index_type, - typename ccl::unified_context_type::ccl_native_t); - -COMM_CREATOR_INDEXED_INSTANTIATION_CONTAINER( - ccl::vector_class, - typename ccl::unified_context_type::ccl_native_t); -COMM_CREATOR_INDEXED_INSTANTIATION_TYPE(typename ccl::unified_device_type::ccl_native_t, - typename ccl::unified_context_type::ccl_native_t); diff --git a/src/common/comm/l0/comm_context.hpp b/src/common/comm/l0/comm_context.hpp deleted file mode 100644 index 070e75ab9..000000000 --- a/src/common/comm/l0/comm_context.hpp +++ /dev/null @@ -1,118 +0,0 @@ -/* - Copyright 2016-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use 
this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ -#include "oneapi/ccl/aliases.hpp" -#include "oneapi/ccl/device_types.hpp" -#include "oneapi/ccl/type_traits.hpp" -#include "oneapi/ccl/types_policy.hpp" -#include "oneapi/ccl/comm_split_attr_ids.hpp" -#include "oneapi/ccl/comm_split_attr_ids_traits.hpp" -#include "oneapi/ccl/comm_split_attr.hpp" - -#include "oneapi/ccl/coll_attr_ids.hpp" -#include "oneapi/ccl/coll_attr_ids_traits.hpp" -#include "oneapi/ccl/coll_attr.hpp" - -#include "oneapi/ccl/stream_attr_ids.hpp" -#include "oneapi/ccl/stream_attr_ids_traits.hpp" -#include "oneapi/ccl/stream.hpp" - -#include "oneapi/ccl/event.hpp" -#include "oneapi/ccl/communicator.hpp" - -#include "common/comm/l0/comm_context_id.hpp" -#include "common/comm/comm_interface.hpp" - -namespace ccl { -namespace detail { -class environment; -} - -class host_communicator; -struct gpu_comm_attr; -using shared_communicator_t = std::shared_ptr; - -class comm_group { -public: - friend class ccl::detail::environment; - friend struct group_context; - - using context_t = typename unified_context_type::ccl_native_t; - - ~comm_group(); - /** - * Device Communicator creation API: single communicator creation, based on @device - */ - template ::type, - ccl::device_index_type>::value, - int>::type = 0> - ccl::communicator_interface_ptr create_communicator_from_group( - const DeviceType& device, - const ContextType& context, - const comm_split_attr& attr = ccl_empty_attr()); - - /** - * Device Communicator creation API: single communicator creation, based on index @device_id - */ - template ::type, 
- ccl::device_index_type>::value, - int>::type = 0> - ccl::communicator_interface_ptr create_communicator_from_group( - const DeviceType& device_id, - const ContextType& context, - const comm_split_attr& attr = ccl_empty_attr()); - - /** - * Device Communicator creation vectorized API: - * multiple communicator creation, based on devices iterator @InputIt - */ - template - std::vector create_communicators_group(InputIt first, - InputIt last, - const ContextType& context, - comm_split_attr attr = ccl_empty_attr()); - - /** - * Device Communicator creation vectorized API: - * multiple communicator creation, based on devices of @Type, packed into container @Container - */ - template