Support AMD MI300 GPU #155

Open · wants to merge 14 commits into main
2 changes: 1 addition & 1 deletion .gitmodules
@@ -1,4 +1,4 @@
[submodule "third_party/msccl"]
path = third_party/msccl
url = https://github.com/Azure/msccl-executor-nccl
branch = msccl-v2.17
branch = msccl-v2.17
5 changes: 2 additions & 3 deletions Makefile
@@ -22,6 +22,5 @@ lint: cpplint mdlint
 	python3 setup.py lint
 
 postinstall:
-	cd msamp/operators/dist_op && bash build.sh && cd -
-	cd msamp/operators/arithmetic && pip install -v . && cd -
-	cd msamp/optim && pip install -v . && cd -
+	cd csrc/extensions && pip install -v . && cd -
+	cd csrc/dist_op && bash build.sh && cd -
29 changes: 29 additions & 0 deletions csrc/dist_op/CMakeLists.txt
@@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

cmake_minimum_required(VERSION 3.10)

project(msamp_dist LANGUAGES CXX)

list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/)

find_package(CUDA QUIET)

if (CUDA_FOUND)
find_package(NCCL REQUIRED)
add_library(msamp_dist SHARED dist.cpp)
target_include_directories(msamp_dist PUBLIC ${CUDA_INCLUDE_DIRS} ${NCCL_INCLUDE_DIR})
target_link_libraries(msamp_dist PUBLIC ${CUDA_LIBRARIES} ${NCCL_LIBRARY})
else()
enable_language(HIP)
find_package(HIP REQUIRED)
find_package(RCCL REQUIRED)
add_library(msamp_dist SHARED dist.cpp)

target_include_directories(msamp_dist PUBLIC ${HIP_INCLUDE_DIRS} ${RCCL_INCLUDE_DIR})
target_link_libraries(msamp_dist PUBLIC ${HIP_LIBRARIES} ${RCCL_LIBRARY})
set_source_files_properties(dist.cpp PROPERTIES LANGUAGE HIP)

endif()

install(TARGETS msamp_dist)
@@ -24,6 +24,5 @@ find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIR NCCL_LIBRARY
if (NCCL_FOUND)
message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIR}, library: ${NCCL_LIBRARY})")
mark_as_advanced(NCCL_INCLUDE_DIR NCCL_LIBRARY)
set(NCCL_VERSION "${NCCL_MAJOR}.${NCCL_MINOR}.${NCCL_PATCH}")

endif ()
44 changes: 44 additions & 0 deletions csrc/dist_op/FindRCCL.cmake
@@ -0,0 +1,44 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# Find the RCCL libraries
#
# The following variables are optionally searched for defaults
# ROCM_PATH: Base directory where all RCCL components are found
#
# The following are set after configuration is done:
# RCCL_FOUND
# RCCL_INCLUDE_DIR
# RCCL_LIBRARY

if(NOT DEFINED ENV{ROCM_PATH})
# Run hipconfig -R to get the ROCm path
execute_process(
COMMAND hipconfig -R
RESULT_VARIABLE HIPCONFIG_RESULT
OUTPUT_VARIABLE ROCM_PATH
OUTPUT_STRIP_TRAILING_WHITESPACE
)

# Check if hipconfig was successful
if(NOT HIPCONFIG_RESULT EQUAL 0)
message(FATAL_ERROR "Failed to run hipconfig -R. Make sure ROCm is installed and hipconfig is available.")
endif()
else()
set(ROCM_PATH $ENV{ROCM_PATH})
endif()

find_path(RCCL_INCLUDE_DIR NAMES rccl.h
PATHS ${ROCM_PATH}/include/rccl /usr/local/include/rccl)

find_library(RCCL_LIBRARY NAMES rccl
PATHS ${ROCM_PATH}/lib /usr/local/lib)

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(RCCL DEFAULT_MSG RCCL_INCLUDE_DIR RCCL_LIBRARY)

if (RCCL_FOUND)
message(STATUS "Found RCCL (include: ${RCCL_INCLUDE_DIR}, library: ${RCCL_LIBRARY})")
mark_as_advanced(RCCL_INCLUDE_DIR RCCL_LIBRARY)

endif ()
File renamed without changes.
31 changes: 16 additions & 15 deletions msamp/operators/dist_op/dist.cpp → csrc/dist_op/dist.cpp
@@ -2,22 +2,22 @@
// Licensed under the MIT License.

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <dlfcn.h>
#include <execinfo.h>
#include <unistd.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#include <rccl.h>
#define stream_t hipStream_t
#else
#include <nccl.h>
#include <nccl_net.h>
#define stream_t cudaStream_t

#endif

enum FP8ModeType {kFP8Disabled, kFP8E4M3, kFp8E5M2};
static enum FP8ModeType gFP8Mode = kFP8Disabled;


extern "C"
void disable_fp8() {
gFP8Mode = kFP8Disabled;
@@ -39,9 +39,9 @@ void enable_fp8_e5m2() {
#undef ncclAllReduce
extern "C"
ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, stream_t stream) {
using ncclAllReduceFuncType = ncclResult_t (*)
(const void*, void*, size_t, ncclDataType_t, ncclRedOp_t, ncclComm_t, cudaStream_t);
(const void*, void*, size_t, ncclDataType_t, ncclRedOp_t, ncclComm_t, stream_t);
ncclAllReduceFuncType real_nccl_all_reduce = reinterpret_cast<ncclAllReduceFuncType>(dlsym(RTLD_NEXT, "ncclAllReduce"));
if (real_nccl_all_reduce == nullptr) {
printf("MSAMP_DistOp: Failed to find ncclAllReduce symbol");
@@ -54,6 +54,7 @@ ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count,
} else {
return real_nccl_all_reduce(sendbuff, recvbuff, count, datatype, op, comm, stream);
}
return real_nccl_all_reduce(sendbuff, recvbuff, count, datatype, op, comm, stream);

}

@@ -63,10 +64,10 @@ ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count,
#undef ncclReduce
extern "C"
ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream)
ncclRedOp_t op, int root, ncclComm_t comm, stream_t stream)
{
using ncclReduceFuncType = ncclResult_t (*)
(const void*, void*, size_t, ncclDataType_t, ncclRedOp_t, int, ncclComm_t, cudaStream_t);
(const void*, void*, size_t, ncclDataType_t, ncclRedOp_t, int, ncclComm_t, stream_t);
ncclReduceFuncType real_nccl_reduce = reinterpret_cast<ncclReduceFuncType>(dlsym(RTLD_NEXT, "ncclReduce"));
if (real_nccl_reduce == nullptr) {
printf("MSAMP_DistOp: Failed to find ncclReduce symbol");
@@ -86,9 +87,9 @@ ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, size_t count, ncc
*/
#undef ncclReduceScatter
ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recvcount,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) {
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, stream_t stream) {
using ncclReduceScatterFuncType = ncclResult_t (*)
(const void*, void*, size_t, ncclDataType_t, ncclRedOp_t, ncclComm*, cudaStream_t);
(const void*, void*, size_t, ncclDataType_t, ncclRedOp_t, ncclComm*, stream_t);
ncclReduceScatterFuncType real_nccl_reduce_scatter = reinterpret_cast<ncclReduceScatterFuncType>(dlsym(RTLD_NEXT, "ncclReduceScatter"));
if (real_nccl_reduce_scatter == nullptr) {
printf("MSAMP_DistOp: Failed to find ncclReduceScatter symbol");
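dist.cpp resolves the real collectives with dlsym(RTLD_NEXT, ...), so the wrappers only take effect when libmsamp_dist.so is loaded ahead of the real NCCL/RCCL, typically via LD_PRELOAD. Below is a minimal host-side sketch of how the extern "C" control functions could be driven at runtime; loading the library with dlopen and the library path are illustrative assumptions, not part of this diff.

#include <dlfcn.h>
#include <cstdio>

int main() {
    // "libmsamp_dist.so" is the default name CMake gives the msamp_dist target; adjust to the install location.
    void* handle = dlopen("libmsamp_dist.so", RTLD_NOW | RTLD_GLOBAL);
    if (handle == nullptr) {
        std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
        return 1;
    }

    // extern "C" entry points defined in dist.cpp.
    using toggle_fn = void (*)();
    auto enable_e5m2 = reinterpret_cast<toggle_fn>(dlsym(handle, "enable_fp8_e5m2"));
    auto disable     = reinterpret_cast<toggle_fn>(dlsym(handle, "disable_fp8"));

    if (enable_e5m2 != nullptr) enable_e5m2();  // subsequent ncclAllReduce/ncclReduce calls take the FP8 E5M2 path
    // ... launch collectives here ...
    if (disable != nullptr) disable();          // restore the pass-through behavior

    dlclose(handle);
    return 0;
}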
29 changes: 16 additions & 13 deletions msamp/common/include/common.h → csrc/extensions/include/common.h
@@ -4,13 +4,18 @@
#ifndef MSAMP_COMMON_H_
#define MSAMP_COMMON_H_

#include <cublasLt.h>
#include <cublas_v2.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <torch/extension.h>
#else
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>
#include "hip_float8.h"
#endif

#include <torch/extension.h>
#include <string>

using namespace std;
@@ -19,10 +24,18 @@ using byte = uint8_t;
using int32 = int32_t;
using fp32 = float;
using fp16 = half;

#ifndef __HIP_PLATFORM_AMD__
using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;

#else
using bf16 = hip_bfloat16;
using fp8e4m3 = hip_f8<hip_f8_type::fp8>;
using fp8e5m2 = hip_f8<hip_f8_type::bf8>;
#endif

template <typename T>
constexpr T DIVUP(const T &x, const T &y) {
return (((x) + ((y)-1)) / (y));
@@ -98,16 +111,6 @@ inline int HIP_GET_BLOCKS(const int n, const int num_threads) {

#define CUDA_KERNEL_LOOP(i, n) for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)

template <typename T, typename S> __host__ __device__ T cast_dtype(const S value) { return T(value); }

template <> __host__ __device__ fp16 cast_dtype(const float value) { return __float2half(value); }

template <> __host__ __device__ bf16 cast_dtype(const float value) { return __float2bfloat16(value); }

template <> __host__ __device__ float cast_dtype(const fp16 value) { return __half2float(value); }

template <> __host__ __device__ float cast_dtype(const bf16 value) { return __bfloat162float(value); }

template <typename T>
struct is_fp8 : std::false_type {};

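With these aliases, code that consumes common.h can be written once and compiled for either the CUDA or the HIP platform. A small host-side sketch under that assumption; the sizes asserted below hold for both the NVIDIA types and the hip_bfloat16/hip_f8 types:

#include "common.h"

// Both platforms store half/bfloat16 in 2 bytes and FP8 in 1 byte.
static_assert(sizeof(fp16) == 2, "fp16 should be 2 bytes");
static_assert(sizeof(bf16) == 2, "bf16 should be 2 bytes");
static_assert(sizeof(fp8e4m3) == 1, "fp8e4m3 should be 1 byte");

int main() {
    const int n = 1000;
    const int threads_per_block = 256;
    const int blocks = DIVUP(n, threads_per_block);  // 4 blocks cover 1000 elements
    return blocks == 4 ? 0 : 1;
}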
@@ -18,7 +18,6 @@ struct DeviceSyncer {
/// Destroy the DeviceSyncer object.
~DeviceSyncer() = default;

#ifdef __CUDACC__
/// Synchronize all threads inside a kernel. Guarantee that all previous work of all threads in cooperating blocks is
/// finished.
/// @param blockNum The number of blocks that will synchronize.
@@ -48,7 +47,6 @@
// the flag is flipped.
__syncthreads();
}
#endif

private:
/// The flag to indicate whether the barrier is reached by the latest thread.
20 changes: 20 additions & 0 deletions csrc/extensions/include/extensions.h
@@ -0,0 +1,20 @@
#include <torch/extension.h>

namespace msamp {
void add_to_fp8(at::Tensor fp8_tensor,
at::Tensor scale,
at::Tensor scale_inv,
at::Tensor amax,
const at::Tensor& other,
bool is_e4m3);

void adamw_fp8_stage1_compute(at::Tensor param, at::Tensor grad, at::Tensor exp_avg_value, float exp_avg_factor,
at::Tensor exp_avg_amax, float beta1, at::Tensor exp_avg_sq_value, float exp_avg_sq_factor,
at::Tensor exp_avg_sq_amax, float beta2, float eps, int step, float lr,
bool bias_correction);

void adamw_fp8_stage2_compute(at::Tensor grad, at::Tensor exp_avg_value, float exp_avg_factor, float new_exp_avg_factor,
float beta1, at::Tensor exp_avg_sq_value, float exp_avg_sq_factor,
float new_exp_avg_sq_factor, float beta2, int step, bool bias_correction);

}
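extensions.h only declares the entry points; the binding code that exposes them to Python is not part of this diff. A minimal sketch of the usual torch-extension registration, assuming the conventional pybind11 pattern (the module name and file layout here are illustrative, not taken from this PR):

#include <torch/extension.h>
#include "include/extensions.h"

// Hypothetical registration file; the actual module name and setup.py wiring under csrc/extensions are not shown in the diff.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("add_to_fp8", &msamp::add_to_fp8, "Accumulate a tensor into an FP8 tensor in place");
    m.def("adamw_fp8_stage1_compute", &msamp::adamw_fp8_stage1_compute, "AdamW FP8 optimizer, stage-1 update");
    m.def("adamw_fp8_stage2_compute", &msamp::adamw_fp8_stage2_compute, "AdamW FP8 optimizer, stage-2 update");
}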