Commit 3ce762b
Add NF4 quantize kernel
wanghaoshuang committed Nov 2, 2023
1 parent d970e20 commit 3ce762b
Showing 6 changed files with 316 additions and 14 deletions.
41 changes: 41 additions & 0 deletions csrc/lc/helper.h
@@ -0,0 +1,41 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/extension.h"

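// Maps a paddle::DataType enum value to two concrete types: DataType is the
// device-side CUDA type used inside kernels (e.g. half, __nv_bfloat16), and
// data_t is the framework-side type accepted by paddle::Tensor::data<T>().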
template <paddle::DataType D>
class PDTraits;

template <>
class PDTraits<paddle::DataType::FLOAT32> {
public:
  typedef float DataType;
  typedef float data_t;
};

template <>
class PDTraits<paddle::DataType::FLOAT16> {
public:
  typedef half DataType;
  typedef paddle::float16 data_t;
};

template <>
class PDTraits<paddle::DataType::BFLOAT16> {
public:
  typedef __nv_bfloat16 DataType;
  typedef paddle::bfloat16 data_t;
};
235 changes: 235 additions & 0 deletions csrc/lc/nf4.cu
@@ -0,0 +1,235 @@
#include <cub/block/block_radix_sort.cuh>
#include <cub/warp/warp_reduce.cuh>
#include <cub/block/block_load.cuh>
#include <cub/block/block_discontinuity.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/cub.cuh>
#include <math_constants.h>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <mma.h>
#include "helper.h"
#include <cfloat>  // FLT_MAX

#define HLF_MAX 65504
#define TH 1024
#define NUM 4
#define NUM_BLOCK 4096

__device__ unsigned char dQuantizeNF4(float x)
{
  // The thresholds in this tree were generated by test_normal_map_tree in
  // bitsandbytes' tests/test_functional.py; each one is (up to float32
  // rounding) the midpoint of two adjacent NF4 codebook values, so the tree
  // is a binary search for the nearest code.
  if(x > 0.03979014977812767f)
    if(x > 0.3893125355243683f) // 1
      if(x > 0.6427869200706482f) // 11
        if(x > 0.8614784181118011f) // 111
          return 0b1111;
        else
          return 0b1110;
      else
        if(x > 0.5016634166240692f) // 110
          return 0b1101;
        else
          return 0b1100;
    else
      if(x > 0.2035212516784668f) // 10
        if(x > 0.2920137718319893f) // 101
          return 0b1011;
        else
          return 0b1010;
      else
        if(x > 0.1202552504837513f) // 100
          return 0b1001;
        else
          return 0b1000;
  else
    if(x > -0.33967943489551544f) // 0
      if(x > -0.13791173323988914f) // 01
        if(x > -0.045525018125772476f) // 011
          return 0b0111;
        else
          return 0b0110;
      else
        if(x > -0.23460740596055984f) // 010
          return 0b0101;
        else
          return 0b0100;
    else
      if(x > -0.6106329262256622f) // 00
        if(x > -0.4599952697753906f) // 001
          return 0b0011;
        else
          return 0b0010;
      else
        if(x > -0.8480964004993439f) // 000
          return 0b0001;
        else
          return 0b0000;
}
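For reference, the thresholds above are, up to float32 rounding, the midpoints between adjacent entries of the 16-value NF4 codebook from the QLoRA paper, so the tree is a branch-based nearest-code lookup. A minimal NumPy sketch of the same mapping (quantize_nf4_ref is a hypothetical helper for illustration, not part of this commit):

import numpy as np

# The 16 NF4 codebook values (normal-distribution quantiles scaled to [-1, 1]).
NF4_CODES = np.array([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.2461123913526535,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
], dtype=np.float32)

def quantize_nf4_ref(x):
    # The index of the nearest codebook entry equals the 4-bit code
    # returned by dQuantizeNF4 (boundary ties aside).
    x = np.asarray(x, dtype=np.float32)
    return np.abs(x[..., None] - NF4_CODES).argmin(axis=-1).astype(np.uint8)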

template<typename T, int BLOCK_SIZE, int NUM_PER_TH>
//__launch_bounds__(TH, 4)
__global__ void kQuantizeBlockwiseNF4(const T* A, float *absmax, unsigned char *out, const int n)
{
  // Total number of elements covered by one grid-stride sweep of all CUDA blocks.
  const int n_full = gridDim.x * BLOCK_SIZE;
  int valid_items = 0;
  // Start index of the elements handled by the current CUDA block.
  const int base_idx = (blockIdx.x * BLOCK_SIZE);
  // Input elements staged by the current CUDA thread.
  T vals[NUM_PER_TH];
  // Output bytes per thread: two 4-bit codes are packed into each byte.
  const int output_num_per_thread = NUM_PER_TH/2;
  unsigned char qvals[output_num_per_thread];
  float local_abs_max = 0.0f;

  typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
  typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH/2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
  typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;

  __shared__ typename LoadT::TempStorage loadt;
  __shared__ typename StoreChar::TempStorage storec;
  __shared__ typename BlockReduce::TempStorage reduce;
  // absmax of the current CUDA block (which is also one quantization block),
  // broadcast to all threads through shared memory.
  __shared__ float smem_absmax_value[1];

  for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
  {
    valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
    local_abs_max = -FLT_MAX;

    __syncthreads();
    LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f);

    // 1. compute local max
    // 2. broadcast local max
    // 3. normalize inputs and quantize

    #pragma unroll NUM_PER_TH
    for(int j = 0; j < NUM_PER_TH; j++)
      local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));

    local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items);

    if(threadIdx.x == 0)
      smem_absmax_value[0] = local_abs_max;

    __syncthreads();

    if(threadIdx.x == 0)
      absmax[i/BLOCK_SIZE] = local_abs_max;
    else
      local_abs_max = smem_absmax_value[0];

    __syncwarp();

    local_abs_max = 1.0f/local_abs_max;

    #pragma unroll
    for(int j = 0; j < NUM_PER_TH/2; j++)
    {
      // Pack two 4-bit codes per byte: the even element goes to the high
      // nibble, the odd element to the low nibble. packed_4bit is
      // re-initialized each iteration so stale bits from the previous pair
      // cannot leak into this byte.
      unsigned char packed_4bit = 0;
      packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
      packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
      qvals[j] = packed_4bit;
    }

    __syncthreads();
    StoreChar(storec).Store(&(out[i/2]), qvals, (valid_items+1)/2);
  }
}
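Taken together, the kernel implements a simple blockwise scheme: compute each block's absmax, normalize the block into [-1, 1], map every value to its nearest NF4 code, and pack two codes per output byte. A host-side NumPy sketch of the same pipeline, reusing the hypothetical quantize_nf4_ref above (the all-zero-block guard is an addition; the kernel itself does not special-case absmax == 0):

def quantize_blockwise_nf4_ref(x, block_size=64):
    flat = np.asarray(x, dtype=np.float32).ravel()
    num_blocks = -(-flat.size // block_size)           # ceil division
    padded = np.zeros(num_blocks * block_size, dtype=np.float32)
    padded[:flat.size] = flat
    blocks = padded.reshape(num_blocks, block_size)
    absmax = np.abs(blocks).max(axis=1)                # one float32 scale per block
    scale = np.where(absmax > 0, absmax, 1.0)          # guard all-zero blocks
    codes = quantize_nf4_ref(blocks / scale[:, None]).reshape(-1)
    packed = (codes[0::2] << 4) | codes[1::2]          # even element -> high nibble
    return packed.astype(np.uint8), absmax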

#define MAKE_kQuantizeBlockwiseNF4(dtype, blocksize, num_per_thread) \
template __global__ void kQuantizeBlockwiseNF4<dtype, blocksize, num_per_thread>(const dtype * A, float *absmax, unsigned char *out, const int n);

// Explicit instantiations covering the block sizes dispatched by LaunchQuantizeNF4 below.
MAKE_kQuantizeBlockwiseNF4(half, 2048, 4)
MAKE_kQuantizeBlockwiseNF4(half, 1024, 4)
MAKE_kQuantizeBlockwiseNF4(half, 512, 2)
MAKE_kQuantizeBlockwiseNF4(half, 256, 2)
MAKE_kQuantizeBlockwiseNF4(half, 128, 2)
MAKE_kQuantizeBlockwiseNF4(half, 64, 2)

MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 2048, 4)
MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 1024, 4)
MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 512, 2)
MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 256, 2)
MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 128, 2)
MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 64, 2)

MAKE_kQuantizeBlockwiseNF4(float, 2048, 4)
MAKE_kQuantizeBlockwiseNF4(float, 1024, 4)
MAKE_kQuantizeBlockwiseNF4(float, 512, 2)
MAKE_kQuantizeBlockwiseNF4(float, 256, 2)
MAKE_kQuantizeBlockwiseNF4(float, 128, 2)
MAKE_kQuantizeBlockwiseNF4(float, 64, 2)

template <paddle::DataType D>
std::vector<paddle::Tensor> LaunchQuantizeNF4(const paddle::Tensor& input, int block_size) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto input_shape = input.shape();
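// Note: the output keeps the input's element count, but the kernel writes
// only the first ceil(n / 2) bytes, since two 4-bit codes share one byte.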
auto output = paddle::full(input_shape, 1, paddle::DataType::UINT8, input.place());
const int n = input.numel();
int num_blocks = n/block_size;
num_blocks = n % block_size == 0 ? num_blocks : num_blocks + 1;

auto abs_max = paddle::full({num_blocks}, 1, paddle::DataType::FLOAT32, input.place());

const DataType_ *in_ptr = reinterpret_cast<const DataType_*>(input.data<data_t>());
unsigned char *out_ptr = output.mutable_data<unsigned char>();
float *abs_max_ptr = abs_max.mutable_data<float>();

if(block_size == 2048) {
kQuantizeBlockwiseNF4<DataType_, 2048, 4><<<num_blocks, 512>>>(in_ptr, abs_max_ptr, out_ptr, n);
} else if(block_size == 1024) {
kQuantizeBlockwiseNF4<DataType_, 1024, 4><<<num_blocks, 256>>>(in_ptr, abs_max_ptr, out_ptr, n);
} else if(block_size == 512) {
kQuantizeBlockwiseNF4<DataType_, 512, 2><<<num_blocks, 256>>>(in_ptr, abs_max_ptr, out_ptr, n);
} else if(block_size == 256) {
kQuantizeBlockwiseNF4<DataType_, 256, 2><<<num_blocks, 128>>>(in_ptr, abs_max_ptr, out_ptr, n);
} else if(block_size == 128) {
kQuantizeBlockwiseNF4<DataType_, 128, 2><<<num_blocks, 64>>>(in_ptr, abs_max_ptr, out_ptr, n);
} else if(block_size == 64) {
kQuantizeBlockwiseNF4<DataType_, 64, 2><<<num_blocks, 32>>>(in_ptr, abs_max_ptr, out_ptr, n);
} else {
    PD_THROW("quantize_nf4: unsupported block_size, expected one of "
             "64, 128, 256, 512, 1024, 2048.");
}
return {output, abs_max};
}
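The dispatch above keeps one CUDA block per quantization block, and each thread loads NUM_PER_TH elements, so the thread count is always block_size / NUM_PER_TH. A small illustrative helper (not part of the commit) for the launch geometry:

def launch_dims(n, block_size, num_per_th):
    num_blocks = (n + block_size - 1) // block_size    # ceil(n / block_size)
    threads = block_size // num_per_th                 # e.g. 2048 / 4 -> 512
    return num_blocks, threads

assert launch_dims(10_000, 64, 2) == (157, 32)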

std::vector<paddle::Tensor> QuantizeNF4(const paddle::Tensor& input, int block_size) {
switch (input.type()) {
case paddle::DataType::BFLOAT16: {
return LaunchQuantizeNF4<paddle::DataType::BFLOAT16>(input, block_size);
}
case paddle::DataType::FLOAT16: {
return LaunchQuantizeNF4<paddle::DataType::FLOAT16>(input, block_size);
}
case paddle::DataType::FLOAT32: {
return LaunchQuantizeNF4<paddle::DataType::FLOAT32>(input, block_size);
}
default: {
PD_THROW(
    "Unsupported data type. "
    "Only bfloat16, float16 and float32 are supported.");
break;
}
}
}

PD_BUILD_OP(quantize_nf4)
    .Inputs({"input"})
    .Outputs({"out", "abs_max"})
    .Attrs({"block_size: int"})
    .SetKernelFn(PD_KERNEL(QuantizeNF4));
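Assuming the block_size attribute registered above, a minimal usage sketch once the extension is built (shapes are illustrative):

import paddle
import paddleslim_ops  # built from csrc/setup_cuda.py

x = paddle.randn([1024, 1024]).astype("float16")
packed, abs_max = paddleslim_ops.quantize_nf4(x, 64)  # block_size = 64
# packed: uint8 tensor, two 4-bit codes per byte (only the first n/2 bytes are used)
# abs_max: float32 tensor, one scale per 64-element quantization block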
2 changes: 2 additions & 0 deletions csrc/requirements.txt
@@ -0,0 +1,2 @@
cupy-cuda116
pybind11
21 changes: 21 additions & 0 deletions csrc/setup_cuda.py
@@ -0,0 +1,21 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.utils.cpp_extension import CUDAExtension, setup

setup(
    name="paddleslim_ops",
    ext_modules=CUDAExtension(sources=[
        "./lc/nf4.cu",
    ]), )
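Assuming the standard Paddle custom-op workflow, the extension is built and installed with python setup_cuda.py install, after which import paddleslim_ops exposes the quantize_nf4 op registered in nf4.cu.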
26 changes: 13 additions & 13 deletions paddleslim/lc/layers/nf4_linear.py
@@ -7,32 +7,35 @@
 class NF4Linear(WeightQuantizationLinear):
     quant_dtype = "int4"
     weight_dtype = "int8"
+    quant_scale_suffix = "quant_scale"
+    double_quant_scale_suffix = "double_quant_scale"

     def __init__(
             self,
             linear: nn.Linear,
             block_size=64,
-            double_quant=False, ):
+            use_double_quant=False, ):
         super(NF4Linear, self).__init__(linear)
         self.block_size = block_size
-        self.double_quant = double_quant
-        self.quantizer = NF4Quantizer(block_size, double_quant)
+        self.double_quant = use_double_quant
+        self.quantizer = NF4Quantizer(block_size, use_double_quant)
         # PaddlePaddle doesn't support an Int4 data type, so one Int8 value holds two Int4 values.
         self.quant_weight = self.create_parameter(
             shape=[self.out_features // 2, self.in_features],
             attr=paddle.ParamAttr(self.quant_weight_name),
             dtype=NF4Linear.weight_dtype,
             is_bias=False, )

-        self.quant_scale_name = ".".join([self.weight_name, "quant_scale"])
+        self.quant_scale_name = ".".join(
+            [self.weight_name, NF4Linear.quant_scale_suffix])
         self.quant_scale = self.create_parameter(
             shape=[self.out_features],
             attr=paddle.ParamAttr(self.quant_scale_name),
             dtype="float32",  # to be fixed
             is_bias=False, )
         if self.double_quant:
             self.double_quant_scale_name = ".".join(
-                [self.weight_name, "double_quant_scale"])
+                [self.weight_name, NF4Linear.double_quant_scale_suffix])
             self.double_quant_scale = self.create_parameter(
                 shape=[self.out_features],
                 attr=paddle.ParamAttr(self.double_quant_scale_name),
@@ -41,14 +44,13 @@ def __init__(

     def quantize(self, weight):
         quantized_weight = self.quantizer.quantize(weight)
-        #self.set_state_dict({self.quant_weight_name: quantized_weight})
-        self.quant_weight.set_value(quantized_weight)
-        #self.set_state_dict({self.quant_scale_name: self.quantizer.quant_scale})
-        self.quant_scale.set_value(self.quantizer.quant_scale)
-        if self.double_quant:
-            #self.set_state_dict({self.double_quant_scale_name: self.quantizer.double_quant_scale})
-            self.double_quant_scale.set_value(self.quantizer.double_quant_scale)
-        return quantized_weight
+        result = {
+            self.quant_weight_name: quantized_weight,
+            self.quant_scale_name: self.quantizer.quant_scale,
+        }
+        if self.double_quant:
+            result[self.double_quant_scale_name] = self.quantizer.double_quant_scale
+        return result

     def forward(self, x):
         self.quantizer.quant_scale = self.state_dict[self.quant_scale_name]
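To make the new call sites concrete, a hypothetical usage sketch of the layer; the import path and the WeightQuantizationLinear constructor behavior are assumptions, since they sit outside this diff:

import paddle.nn as nn
from paddleslim.lc.layers import NF4Linear  # assumed export path

linear = nn.Linear(128, 256)
qlayer = NF4Linear(linear, block_size=64, use_double_quant=False)
# quantize() now returns a {parameter_name: tensor} dict for the caller to
# apply, instead of writing parameter values in place.
state = qlayer.quantize(linear.weight)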
5 changes: 4 additions & 1 deletion paddleslim/lc/quantizers/nf4.py
@@ -1,5 +1,6 @@
 import paddle
 from .base_quantizer import BaseQuantizer
+import paddleslim_ops


 class NF4Quantizer(BaseQuantizer):
@@ -13,7 +14,9 @@ def __init__(self, block_size=64, double_quant=False):
         self.double_quant_scale = None

     def quantize(self, x: paddle.Tensor):
-        return x
+        out, abs_max = paddleslim_ops.quantize_nf4(x, self.block_size)
+        self.quant_scale = abs_max
+        return out

     def dequantize(self, x: paddle.Tensor):
         return x
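The diff leaves dequantize as an identity. For orientation, a NumPy sketch of the inverse mapping, reusing the hypothetical NF4_CODES table from the nf4.cu section (not part of this commit):

import numpy as np

def dequantize_blockwise_nf4_ref(packed, absmax, block_size=64):
    codes = np.empty(packed.size * 2, dtype=np.uint8)
    codes[0::2] = packed >> 4            # high nibble holds the even element
    codes[1::2] = packed & 0x0F          # low nibble holds the odd element
    values = NF4_CODES[codes].reshape(-1, block_size)
    return (values * absmax[:, None]).reshape(-1)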
