Skip to content

Commit

Permalink
Add the cuda dequantize f16 kernels. (#2137)
Browse files Browse the repository at this point in the history
* Add the cuda dequantize f16 kernels.

* Expose the cuda kernels.

* Add some testing + fix.

* Test the other cases too.

* A few more tests.

* Add an environment variable to enable the dequantize f16 + matmul behavior.
  • Loading branch information
LaurentMazare authored Apr 28, 2024
1 parent c68ed89 commit eb26e24
Show file tree
Hide file tree
Showing 5 changed files with 317 additions and 55 deletions.
88 changes: 75 additions & 13 deletions candle-core/src/quantized/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::{GgmlDType, QStorage};
use crate::quantized::k_quants::GgmlType;
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{CudaDevice, CudaStorage, Result};
use half::f16;

use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};

Expand Down Expand Up @@ -59,7 +60,7 @@ fn quantize_q8_1(
Ok(())
}

fn dequantize(
fn dequantize_f32(
data: &CudaSlice<u8>,
dtype: GgmlDType,
elem_count: usize,
Expand All @@ -69,27 +70,27 @@ fn dequantize(

let nb = (elem_count + 255) / 256;
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
GgmlDType::Q5_0 => (
"dequantize_block_q5_0",
"dequantize_block_q5_0_f32",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q5_1 => (
"dequantize_block_q5_1",
"dequantize_block_q5_1_f32",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
GgmlDType::Q8_0 => ("dequantize_block_q8_0_f32", false, 32, nb),
GgmlDType::Q2K => ("dequantize_block_q2_K_f32", true, 64, nb),
GgmlDType::Q3K => ("dequantize_block_q3_K_f32", true, 64, nb),
GgmlDType::Q4K => ("dequantize_block_q4_K_f32", true, 32, nb),
GgmlDType::Q5K => ("dequantize_block_q5_K_f32", true, 64, nb),
GgmlDType::Q6K => ("dequantize_block_q6_K_f32", true, 64, nb),
GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
Expand All @@ -116,6 +117,63 @@ fn dequantize(
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}

fn dequantize_f16(
data: &CudaSlice<u8>,
dtype: GgmlDType,
elem_count: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;

let nb = (elem_count + 255) / 256;
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
GgmlDType::Q5_0 => (
"dequantize_block_q5_0_f16",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q5_1 => (
"dequantize_block_q5_1_f16",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q8_0 => ("dequantize_block_q8_0_f16", false, 32, nb),
GgmlDType::Q2K => ("dequantize_block_q2_K_f16", true, 64, nb),
GgmlDType::Q3K => ("dequantize_block_q3_K_f16", true, 64, nb),
GgmlDType::Q4K => ("dequantize_block_q4_K_f16", true, 32, nb),
GgmlDType::Q5K => ("dequantize_block_q5_K_f16", true, 64, nb),
GgmlDType::Q6K => ("dequantize_block_q6_K_f16", true, 64, nb),
GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, 1, 1),
block_dim: (block_dim as u32, 1, 1),
shared_mem_bytes: 0,
};

if is_k {
let params = (data, &dst);
unsafe { func.launch(cfg, params) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let params = (data, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}

fn dequantize_mul_mat_vec(
data: &CudaSlice<u8>,
y: &CudaView<f32>,
Expand Down Expand Up @@ -341,7 +399,7 @@ impl QCudaStorage {
| GgmlDType::Q8K
);
if fast_kernel {
return dequantize(&self.data, self.dtype, elem_count, self.device());
return dequantize_f32(&self.data, self.dtype, elem_count, self.device());
}
// Run the dequantization on cpu.

Expand Down Expand Up @@ -369,6 +427,10 @@ impl QCudaStorage {
.storage_from_cpu_storage(&crate::CpuStorage::F32(out))
}

pub fn dequantize_f16(&self, elem_count: usize) -> Result<CudaStorage> {
dequantize_f16(&self.data, self.dtype, elem_count, self.device())
}

pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
// Run the quantization on cpu.
let src = match &src.slice {
Expand Down
4 changes: 4 additions & 0 deletions candle-core/src/quantized/dummy_cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ impl QCudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}

pub fn dequantize_f16(&self, _elem_count: usize) -> Result<CudaStorage> {
Err(Error::NotCompiledWithCudaSupport)
}

pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
Expand Down
47 changes: 43 additions & 4 deletions candle-core/src/quantized/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
use k_quants::*;
use std::borrow::Cow;

Expand Down Expand Up @@ -360,9 +360,24 @@ impl QTensor {
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
let storage = self.storage.dequantize(self.shape.elem_count())?;
let none = crate::op::BackpropOp::none();
let is_variable = false;
crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
.to_device(device)
crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
}

pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
// In the CUDA case, we have a specialized kernel as this can be useful for volta
// architectures. https://github.com/huggingface/candle/issues/2136
match &self.storage {
QStorage::Cuda(s) => {
let s = s.dequantize_f16(self.shape.elem_count())?;
let none = crate::op::BackpropOp::none();
crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
.to_device(device)
}
_ => {
let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
Ok(s)
}
}
}

pub fn storage_size_in_bytes(&self) -> usize {
Expand All @@ -378,6 +393,7 @@ impl QTensor {
pub enum QMatMul {
QTensor(std::sync::Arc<QTensor>),
Tensor(Tensor),
TensorF16(Tensor),
}

thread_local! {
Expand All @@ -391,6 +407,17 @@ thread_local! {
}
}

thread_local! {
static DEQUANTIZE_ALL_F16: bool = {
match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") {
Ok(s) => {
!s.is_empty() && s != "0"
},
Err(_) => false,
}
}
}

impl QMatMul {
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
let dequantize = match qtensor.dtype() {
Expand All @@ -400,6 +427,9 @@ impl QMatMul {
let t = if dequantize {
let tensor = qtensor.dequantize(&qtensor.device())?;
Self::Tensor(tensor)
} else if DEQUANTIZE_ALL_F16.with(|b| *b) {
let tensor = qtensor.dequantize_f16(&qtensor.device())?;
Self::TensorF16(tensor)
} else {
Self::QTensor(qtensor)
};
Expand Down Expand Up @@ -486,6 +516,15 @@ impl crate::Module for QMatMul {
};
xs.matmul(&w)
}
Self::TensorF16(w) => {
let in_dtype = xs.dtype();
let w = match *xs.dims() {
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
_ => w.t()?,
};
xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
}
}
}
}
Loading

0 comments on commit eb26e24

Please sign in to comment.