Commit 3ce762b (1 parent: d970e20)
Showing 6 changed files with 316 additions and 14 deletions.
@@ -0,0 +1,41 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/extension.h"

template <paddle::DataType D>
class PDTraits;

template <>
class PDTraits<paddle::DataType::FLOAT32> {
 public:
  typedef float DataType;
  typedef float data_t;
};

template <>
class PDTraits<paddle::DataType::FLOAT16> {
 public:
  typedef half DataType;
  typedef paddle::float16 data_t;
};

template <>
class PDTraits<paddle::DataType::BFLOAT16> {
 public:
  typedef __nv_bfloat16 DataType;
  typedef paddle::bfloat16 data_t;
};
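A minimal sketch (not part of this commit) of how these traits are typically consumed: the paddle::DataType enum value selects both the device-side compute type (DataType) and the host-facing element type (data_t), so a single templated launcher can reinterpret a tensor's buffer for a CUDA kernel. The helper name DevicePtr below is hypothetical; LaunchQuantizeNF4 in the next file does the same thing inline.

// Hypothetical illustration of PDTraits usage (not part of the commit):
// data_t matches the element type expected by paddle::Tensor::data<T>(),
// while DataType is the type the CUDA kernel computes with.
template <paddle::DataType D>
const typename PDTraits<D>::DataType* DevicePtr(const paddle::Tensor& t) {
  typedef typename PDTraits<D>::data_t data_t;
  typedef typename PDTraits<D>::DataType DataType_;
  return reinterpret_cast<const DataType_*>(t.data<data_t>());
}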
@@ -0,0 +1,235 @@
#include <cub/block/block_radix_sort.cuh>
#include <cub/warp/warp_reduce.cuh>
#include <cub/block/block_load.cuh>
#include <cub/block/block_discontinuity.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/cub.cuh>
#include <math_constants.h>
#include <cfloat>  // FLT_MAX, used in kQuantizeBlockwiseNF4
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <mma.h>
#include "helper.h"
#include <iostream>
using namespace std;

#define HLF_MAX 65504
#define TH 1024
#define NUM 4
#define NUM_BLOCK 4096

__device__ unsigned char dQuantizeNF4(float x)
{
  // the values for this tree were generated by test_normal_map_tree
  // in the file tests/test_functional.py
  if(x > 0.03979014977812767f)
    if(x > 0.3893125355243683f) // 1
      if(x > 0.6427869200706482f) // 11
        if(x > 0.8614784181118011f) // 111
          return 0b1111;
        else
          return 0b1110;
      else
        if(x > 0.5016634166240692f) // 110
          return 0b1101;
        else
          return 0b1100;
    else
      if(x > 0.2035212516784668f) // 10
        if(x > 0.2920137718319893f) // 101
          return 0b1011;
        else
          return 0b1010;
      else
        if(x > 0.1202552504837513f) // 100
          return 0b1001;
        else
          return 0b1000;
  else
    if(x > -0.33967943489551544f) // 0
      if(x > -0.13791173323988914f) // 01
        if(x > -0.045525018125772476f) // 011
          return 0b0111;
        else
          return 0b0110;
      else
        if(x > -0.23460740596055984f) // 010
          return 0b0101;
        else
          return 0b0100;
    else
      if(x > -0.6106329262256622f) // 00
        if(x > -0.4599952697753906f) // 001
          return 0b0011;
        else
          return 0b0010;
      else
        if(x > -0.8480964004993439f) // 000
          return 0b0001;
        else
          return 0b0000;
}

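For reference (not part of this commit): each threshold in the tree above is the midpoint of two adjacent entries of the 16-value NF4 code book used by bitsandbytes, from which this tree appears to be taken, so the function is a hard-coded nearest-neighbour search over that code book for inputs normalized to [-1, 1]. A host-side sketch of the equivalent table-based lookup, with the code-book constants reproduced here purely for illustration:

#include <cmath>

// NF4 code book (16 normal-distribution quantiles); the thresholds in
// dQuantizeNF4 above are the midpoints of adjacent entries of this table.
static const float kNF4Codebook[16] = {
    -1.0f, -0.6961928009986877f, -0.5250730514526367f, -0.39491748809814453f,
    -0.28444138169288635f, -0.18477343022823334f, -0.09105003625154495f, 0.0f,
    0.07958029955625534f, 0.16093020141124725f, 0.24611230194568634f,
    0.33791524171829224f, 0.44070982933044434f, 0.5626170039176941f,
    0.7229568362236023f, 1.0f};

// Linear-scan reference: returns the index of the closest code-book value,
// which matches the 4-bit result of the branch tree (up to tie-breaking
// exactly at a midpoint).
unsigned char dQuantizeNF4Reference(float x) {
  unsigned char best = 0;
  float best_dist = std::fabs(x - kNF4Codebook[0]);
  for (int i = 1; i < 16; ++i) {
    float dist = std::fabs(x - kNF4Codebook[i]);
    if (dist < best_dist) { best_dist = dist; best = (unsigned char)i; }
  }
  return best;
}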
template<typename T, int BLOCK_SIZE, int NUM_PER_TH>
//__launch_bounds__(TH, 4)
__global__ void kQuantizeBlockwiseNF4(const T* A, float *absmax, unsigned char *out, const int n)
{
  // total number of elements covered by one pass of all CUDA blocks
  const int n_full = gridDim.x * BLOCK_SIZE;
  int valid_items = 0;
  // start index of the elements handled by this CUDA block
  const int base_idx = (blockIdx.x * BLOCK_SIZE);
  // input elements handled by this CUDA thread
  T vals[NUM_PER_TH];
  // number of output bytes produced by this CUDA thread
  const int output_num_per_thread = NUM_PER_TH/2;
  // output bytes produced by this CUDA thread
  unsigned char qvals[output_num_per_thread];
  //float local_abs_max = -FLT_MAX;
  float local_abs_max = 0.0f;
  typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
  typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH/2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
  typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
  typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;

  __shared__ typename LoadT::TempStorage loadt;
  __shared__ typename LoadFloat::TempStorage loadf;
  __shared__ typename StoreChar::TempStorage storec;
  __shared__ typename BlockReduce::TempStorage reduce;
  // absmax of this CUDA block (which is also one quantization block)
  __shared__ float smem_absmax_value[1];

  for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
  {
    valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
    local_abs_max = -FLT_MAX;

    __syncthreads();
    LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f);

    // 1. compute local max
    // 2. broadcast local max
    // 3. normalize inputs and quantize

    #pragma unroll NUM_PER_TH
    for(int j = 0; j < NUM_PER_TH; j++)
      local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));

    local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items);

    if(threadIdx.x == 0)
      smem_absmax_value[0] = local_abs_max;

    __syncthreads();

    if(threadIdx.x == 0)
      absmax[i/BLOCK_SIZE] = local_abs_max;
    else
      local_abs_max = smem_absmax_value[0];

    __syncwarp();

    local_abs_max = 1.0f/local_abs_max;

    unsigned char packed_4bit = 0;

    #pragma unroll NUM_PER_TH
    for(int j = 0; j < NUM_PER_TH/2; j++)
    {
      // reset the byte before packing the next pair; without this, bits from
      // the previous iteration would leak into every following output byte
      packed_4bit = 0;
      packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
      packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
      qvals[j] = packed_4bit;
    }

    __syncthreads();
    StoreChar(storec).Store(&(out[i/2]), qvals, (valid_items+1)/2);
  }
}

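To make the output layout concrete (illustration only, not part of this commit): each quantization block of BLOCK_SIZE input elements yields BLOCK_SIZE/2 output bytes, with the even-indexed element in the high nibble and the odd-indexed element in the low nibble, plus one absmax float per block. Dequantization is then a table lookup scaled by absmax; the sketch below reuses the hypothetical kNF4Codebook table from the earlier sketch.

// Host-side dequantization of one quantization block (illustration only).
// `packed` holds block_size/2 bytes written by kQuantizeBlockwiseNF4,
// `absmax` is the scale stored for this block, and kNF4Codebook is the
// hypothetical 16-entry table from the previous sketch.
void DequantizeNF4Block(const unsigned char* packed, float absmax,
                        int block_size, float* out) {
  for (int j = 0; j < block_size / 2; ++j) {
    const unsigned char byte = packed[j];
    out[2 * j]     = kNF4Codebook[byte >> 4]  * absmax;  // high nibble: vals[2*j]
    out[2 * j + 1] = kNF4Codebook[byte & 0xF] * absmax;  // low nibble:  vals[2*j+1]
  }
}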
#define MAKE_kQuantizeBlockwiseNF4(dtype, blocksize, num_per_thread) \
  template __global__ void kQuantizeBlockwiseNF4<dtype, blocksize, num_per_thread>(const dtype * A, float *absmax, unsigned char *out, const int n);

MAKE_kQuantizeBlockwiseNF4(half, 4096, 4)
MAKE_kQuantizeBlockwiseNF4(half, 1024, 4)
MAKE_kQuantizeBlockwiseNF4(half, 512, 2)
MAKE_kQuantizeBlockwiseNF4(half, 256, 2)
MAKE_kQuantizeBlockwiseNF4(half, 128, 2)
MAKE_kQuantizeBlockwiseNF4(half, 64, 2)

MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 4096, 4)
MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 1024, 4)
MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 512, 2)
MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 256, 2)
MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 128, 2)
MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 64, 2)

MAKE_kQuantizeBlockwiseNF4(float, 4096, 4)
MAKE_kQuantizeBlockwiseNF4(float, 1024, 4)
MAKE_kQuantizeBlockwiseNF4(float, 512, 2)
MAKE_kQuantizeBlockwiseNF4(float, 256, 2)
MAKE_kQuantizeBlockwiseNF4(float, 128, 2)
MAKE_kQuantizeBlockwiseNF4(float, 64, 2)

template <paddle::DataType D>
std::vector<paddle::Tensor> LaunchQuantizeNF4(const paddle::Tensor& input, int block_size) {
  cout << "LaunchQuantizeNF4 begin-------" << endl;
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;
  auto input_shape = input.shape();
  auto output = paddle::full(input_shape, 1, paddle::DataType::UINT8, input.place());
  const int n = input.numel();
  int num_blocks = n/block_size;
  num_blocks = n % block_size == 0 ? num_blocks : num_blocks + 1;

  auto abs_max = paddle::full({num_blocks}, 1, paddle::DataType::FLOAT32, input.place());

  const DataType_ *in_ptr = reinterpret_cast<const DataType_*>(input.data<data_t>());
  unsigned char *out_ptr = output.mutable_data<unsigned char>();
  float *abs_max_ptr = abs_max.mutable_data<float>();

  if(block_size == 2048) {
    kQuantizeBlockwiseNF4<DataType_, 2048, 4><<<num_blocks, 512>>>(in_ptr, abs_max_ptr, out_ptr, n);
  } else if(block_size == 1024) {
    kQuantizeBlockwiseNF4<DataType_, 1024, 4><<<num_blocks, 256>>>(in_ptr, abs_max_ptr, out_ptr, n);
  } else if(block_size == 512) {
    kQuantizeBlockwiseNF4<DataType_, 512, 2><<<num_blocks, 256>>>(in_ptr, abs_max_ptr, out_ptr, n);
  } else if(block_size == 256) {
    kQuantizeBlockwiseNF4<DataType_, 256, 2><<<num_blocks, 128>>>(in_ptr, abs_max_ptr, out_ptr, n);
  } else if(block_size == 128) {
    kQuantizeBlockwiseNF4<DataType_, 128, 2><<<num_blocks, 64>>>(in_ptr, abs_max_ptr, out_ptr, n);
  } else if(block_size == 64) {
    kQuantizeBlockwiseNF4<DataType_, 64, 2><<<num_blocks, 32>>>(in_ptr, abs_max_ptr, out_ptr, n);
  }
  return {output, abs_max};
}

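A quick sanity check on the launch configuration above (added for illustration, not part of the commit): each CUDA block handles one quantization block using BLOCK_SIZE/NUM_PER_TH threads, so for block_size = 512 and NUM_PER_TH = 2 the kernel needs 512 / 2 = 256 threads, which matches the <<<num_blocks, 256>>> launch. num_blocks is ceil(n / block_size); for example, n = 1000 with block_size = 512 gives num_blocks = 2, and the second block processes only valid_items = 1000 - 512 = 488 elements.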
std::vector<paddle::Tensor> QuantizeNF4(const paddle::Tensor& input, int block_size) {
  cout << "QuantizeNF4 begin-------" << endl;
  switch (input.type()) {
    case paddle::DataType::BFLOAT16: {
      return LaunchQuantizeNF4<paddle::DataType::BFLOAT16>(input, block_size);
    }
    case paddle::DataType::FLOAT16: {
      return LaunchQuantizeNF4<paddle::DataType::FLOAT16>(input, block_size);
    }
    case paddle::DataType::FLOAT32: {
      return LaunchQuantizeNF4<paddle::DataType::FLOAT32>(input, block_size);
    }
    default: {
      PD_THROW(
          "NOT supported data type. "
          "Only bfloat16, float16 and float32 are supported. ");
      break;
    }
  }
}

PD_BUILD_OP(quantize_nf4)
    .Inputs({"input"})
    .Outputs({"out", "abs_max"})
    .SetKernelFn(PD_KERNEL(QuantizeNF4));
@@ -0,0 +1,2 @@
cupy-cuda116
pybind11
@@ -0,0 +1,21 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.utils.cpp_extension import CUDAExtension, setup

setup(
    name="paddleslim_ops",
    ext_modules=CUDAExtension(sources=[
        "./lc/nf4.cu",
    ]),
)