AMD Support #187

Merged · 9 commits · May 20, 2024
12 changes: 12 additions & 0 deletions README.md
@@ -8,11 +8,14 @@ PowerInfer is a CPU/GPU LLM inference engine leveraging **activation locality**
[Project Kanban](https://github.com/orgs/SJTU-IPADS/projects/2/views/2)

## Latest News 🔥
- [2024/5/20] **Competition Recruitment: CCF-TCArch Customized Computing Challenge 2024**. The CCF TCARCH CCC is a national competition organized by the Technical Committee on Computer Architecture (TCARCH) of the China Computer Federation (CCF). This year's competition aims to optimize the PowerInfer inference engine using the open-source ROCm/HIP stack. More information about the competition can be found [here](https://ccf-tcarch-ccc.github.io/2024/).
- [2024/5/17] We now support AMD devices with ROCm. (WIP: there are known issues for models exceeding 40B.)
- [2024/3/28] We are thrilled to present [Bamboo LLM](https://github.com/SJTU-IPADS/Bamboo), which achieves both top-level performance and unparalleled speed with PowerInfer! Experience it with Bamboo-7B [Base](https://huggingface.co/PowerInfer/Bamboo-base-v0.1-gguf) / [DPO](https://huggingface.co/PowerInfer/Bamboo-DPO-v0.1-gguf).
- [2024/3/14] We supported ProSparse Llama 2 ([7B](https://huggingface.co/SparseLLM/prosparse-llama-2-7b)/[13B](https://huggingface.co/SparseLLM/prosparse-llama-2-13b)), ReLU models with ~90% sparsity, matching original Llama 2's performance (Thanks THUNLP & ModelBest)!
- [2024/1/11] We supported Windows with GPU inference!
- [2023/12/24] We released an online [gradio demo](https://powerinfer-gradio.vercel.app/) for Falcon(ReLU)-40B-FP16!
- [2023/12/19] We officially released PowerInfer!

## Demo 🔥

https://github.com/SJTU-IPADS/PowerInfer/assets/34213478/fe441a42-5fce-448b-a3e5-ea4abb43ba23
@@ -102,6 +105,7 @@ cd PowerInfer
pip install -r requirements.txt # install Python helpers' dependencies
```
### Build

To build PowerInfer you have two options. These commands should be run from the root directory of the project.

Using `CMake`(3.17+):
@@ -110,7 +114,15 @@ Using `CMake`(3.17+):
cmake -S . -B build -DLLAMA_CUBLAS=ON
cmake --build build --config Release
```
* If you have an AMD GPU:
```bash
# Replace 'gfx1100' below with your GPU's architecture name; you can find it with rocminfo
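# For example, a command along these lines usually prints the target name
# (a sketch; the exact rocminfo output format may vary by ROCm version):
#   rocminfo | grep -o -m 1 'gfx[0-9a-f]*'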
CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++ cmake -S . -B build -DLLAMA_HIPBLAS=ON -DAMDGPU_TARGETS=gfx1100
cmake --build build --config Release
```

* If you have only a CPU:

```bash
cmake -S . -B build
cmake --build build --config Release
49 changes: 42 additions & 7 deletions ggml-cuda.cu
@@ -4614,6 +4614,44 @@ static __global__ void dequantize_mul_mat_axpy_sparse_batch(const void * __restr
}
}

// nrows: 11008(or 32 * x < 11008), ncols: 4096
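// Each thread block produces 64 output values: 2 per lane across 32 lanes.
// Rows are strided across the block's 32 warps, and rows whose activation
// y[lst[row]] (or y[row] when lst is null) is zero are skipped entirely.
// Per-warp partial sums are merged into a 64-entry shared accumulator with
// atomicAdd, and warp 0 then writes the block's results to dst.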
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_mul_mat_axpy_sparse_lessatom(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, int *lst, float *idx) {
int warp_id = threadIdx.y;
int tid = threadIdx.x + blockIdx.x * 32;
int col = tid * 2;
dfloat2 v;
int iqs = (col % qk) / qr;
float tmp[2];
tmp[0] = 0.0;
tmp[1] = 0.0;
__shared__ float res[64];
res[threadIdx.x] = 0.0;
res[threadIdx.x + 32] = 0.0;

#pragma unroll 32
for (int row = warp_id; row < nrows; row += 32) {
int raw_row = lst ? lst[row] : row;
// int raw_row = row;
dfloat y_row = y[raw_row];
if (y_row == 0.0) {
continue;
}
const int ib = (row * ncols + col) / qk;
dequantize_kernel(vx, ib, iqs, v);
tmp[0] += v.x * y_row;
tmp[1] += v.y * y_row;
}
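// Merge this warp's two partial sums into the shared accumulator. Lanes 0-15
// and lanes 16-31 use disjoint 32-entry halves of res, with tmp[0] and tmp[1]
// going to separate 16-entry slots; atomicAdd handles contention across warps.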
const int adder_loc = threadIdx.x % 16 + threadIdx.x / 16 * 32;
atomicAdd(res + adder_loc, tmp[0]);
atomicAdd(res + adder_loc + 16, tmp[1]);
__syncthreads();
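// After the barrier, warp 0 alone copies the 64 accumulated sums to this
// block's slice of dst.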
if (warp_id < 1) {
int write_back_loc = warp_id * 32 + threadIdx.x;
dst[write_back_loc + blockIdx.x * 64] = res[write_back_loc];
}
}

template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_mul_mat_vec_sparse(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, int * lst, float * idx) {
// qk = quantized weights per x block
@@ -5598,13 +5636,10 @@ static void dequantize_axpy_vec_q4_0_cuda(const void * vx, const dfloat * y, flo
}
static void dequantize_axpy_sparse_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream, int *lst, float *idx) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
// dequantize_mul_mat_axpy<QK4_0, QR4_0, dequantize_q4_0>
// <<<block_nums, block_dims, ncols*sizeof(float), stream>>>(vx, y, dst, ncols, nrows);
dequantize_mul_mat_axpy_sparse<QK4_0, QR4_0, dequantize_q4_0>
<<<block_nums, block_dims, ncols*sizeof(float), stream>>>(vx, y, dst, ncols, nrows, lst, idx);
const dim3 block_dim = dim3(32, 32);
const int block_num = (ncols + 63) / 64;
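// 32 lanes x 32 warps per block; each block covers 64 output columns, and the
// kernel's fixed 64-float shared array means no dynamic shared memory is needed.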
dequantize_mul_mat_axpy_sparse_lessatom<QK4_0, QR4_0, dequantize_q4_0>
<<<block_num, block_dim, 0, stream>>>(vx, y, dst, ncols, nrows, lst, idx);
}

static void dequantize_axpy_sparse_batch_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, int src1_rows, int src1_ncols, cudaStream_t stream, int *lst, float *idx) {