diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index 5481ca3c25..8e4884b28d 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -2,6 +2,7 @@ use super::{GgmlDType, QStorage}; use crate::quantized::k_quants::GgmlType; use crate::{backend::BackendDevice, cuda_backend::WrapErr}; use crate::{CudaDevice, CudaStorage, Result}; +use half::f16; use cudarc::driver::{CudaSlice, CudaView, DeviceSlice}; @@ -59,7 +60,7 @@ fn quantize_q8_1( Ok(()) } -fn dequantize( +fn dequantize_f32( data: &CudaSlice, dtype: GgmlDType, elem_count: usize, @@ -69,27 +70,27 @@ fn dequantize( let nb = (elem_count + 255) / 256; let (kernel_name, is_k, block_dim, num_blocks) = match dtype { - GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb), - GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb), + GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb), + GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb), GgmlDType::Q5_0 => ( - "dequantize_block_q5_0", + "dequantize_block_q5_0_f32", false, CUDA_DEQUANTIZE_BLOCK_SIZE, ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE), ), GgmlDType::Q5_1 => ( - "dequantize_block_q5_1", + "dequantize_block_q5_1_f32", false, CUDA_DEQUANTIZE_BLOCK_SIZE, ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE), ), - GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb), - GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb), - GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb), - GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb), - GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb), - GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb), - GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb), + GgmlDType::Q8_0 => ("dequantize_block_q8_0_f32", false, 32, nb), + GgmlDType::Q2K => ("dequantize_block_q2_K_f32", true, 64, nb), + GgmlDType::Q3K => ("dequantize_block_q3_K_f32", true, 64, nb), + GgmlDType::Q4K => ("dequantize_block_q4_K_f32", true, 32, nb), + GgmlDType::Q5K => ("dequantize_block_q5_K_f32", true, 64, nb), + GgmlDType::Q6K => ("dequantize_block_q6_K_f32", true, 64, nb), + GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb), _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"), }; let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?; @@ -116,6 +117,63 @@ fn dequantize( Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) } +fn dequantize_f16( + data: &CudaSlice, + dtype: GgmlDType, + elem_count: usize, + dev: &CudaDevice, +) -> Result { + use cudarc::driver::LaunchAsync; + + let nb = (elem_count + 255) / 256; + let (kernel_name, is_k, block_dim, num_blocks) = match dtype { + GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb), + GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb), + GgmlDType::Q5_0 => ( + "dequantize_block_q5_0_f16", + false, + CUDA_DEQUANTIZE_BLOCK_SIZE, + ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE), + ), + GgmlDType::Q5_1 => ( + "dequantize_block_q5_1_f16", + false, + CUDA_DEQUANTIZE_BLOCK_SIZE, + ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE), + ), + GgmlDType::Q8_0 => ("dequantize_block_q8_0_f16", false, 32, nb), + GgmlDType::Q2K => ("dequantize_block_q2_K_f16", true, 64, nb), + GgmlDType::Q3K => ("dequantize_block_q3_K_f16", true, 64, nb), + GgmlDType::Q4K => ("dequantize_block_q4_K_f16", true, 32, nb), + GgmlDType::Q5K => ("dequantize_block_q5_K_f16", true, 64, nb), + GgmlDType::Q6K => ("dequantize_block_q6_K_f16", true, 64, nb), + GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb), + _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"), + }; + let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?; + let dst = unsafe { dev.alloc::(elem_count).w()? }; + // See e.g. + // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270 + let cfg = cudarc::driver::LaunchConfig { + grid_dim: (num_blocks as u32, 1, 1), + block_dim: (block_dim as u32, 1, 1), + shared_mem_bytes: 0, + }; + + if is_k { + let params = (data, &dst); + unsafe { func.launch(cfg, params) }.w()?; + } else { + let nb32 = match dtype { + GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count, + _ => elem_count / 32, + }; + let params = (data, &dst, nb32 as i32); + unsafe { func.launch(cfg, params) }.w()?; + } + Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) +} + fn dequantize_mul_mat_vec( data: &CudaSlice, y: &CudaView, @@ -341,7 +399,7 @@ impl QCudaStorage { | GgmlDType::Q8K ); if fast_kernel { - return dequantize(&self.data, self.dtype, elem_count, self.device()); + return dequantize_f32(&self.data, self.dtype, elem_count, self.device()); } // Run the dequantization on cpu. @@ -369,6 +427,10 @@ impl QCudaStorage { .storage_from_cpu_storage(&crate::CpuStorage::F32(out)) } + pub fn dequantize_f16(&self, elem_count: usize) -> Result { + dequantize_f16(&self.data, self.dtype, elem_count, self.device()) + } + pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> { // Run the quantization on cpu. let src = match &src.slice { diff --git a/candle-core/src/quantized/dummy_cuda.rs b/candle-core/src/quantized/dummy_cuda.rs index 598c5cd131..ca7b812084 100644 --- a/candle-core/src/quantized/dummy_cuda.rs +++ b/candle-core/src/quantized/dummy_cuda.rs @@ -24,6 +24,10 @@ impl QCudaStorage { Err(Error::NotCompiledWithCudaSupport) } + pub fn dequantize_f16(&self, _elem_count: usize) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 47307f2e70..e87072bbdb 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -1,4 +1,4 @@ -use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor}; +use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor}; use k_quants::*; use std::borrow::Cow; @@ -360,9 +360,24 @@ impl QTensor { pub fn dequantize(&self, device: &Device) -> Result { let storage = self.storage.dequantize(self.shape.elem_count())?; let none = crate::op::BackpropOp::none(); - let is_variable = false; - crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable) - .to_device(device) + crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device) + } + + pub fn dequantize_f16(&self, device: &Device) -> Result { + // In the CUDA case, we have a specialized kernel as this can be useful for volta + // architectures. https://github.com/huggingface/candle/issues/2136 + match &self.storage { + QStorage::Cuda(s) => { + let s = s.dequantize_f16(self.shape.elem_count())?; + let none = crate::op::BackpropOp::none(); + crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false) + .to_device(device) + } + _ => { + let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?; + Ok(s) + } + } } pub fn storage_size_in_bytes(&self) -> usize { @@ -378,6 +393,7 @@ impl QTensor { pub enum QMatMul { QTensor(std::sync::Arc), Tensor(Tensor), + TensorF16(Tensor), } thread_local! { @@ -391,6 +407,17 @@ thread_local! { } } +thread_local! { + static DEQUANTIZE_ALL_F16: bool = { + match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") { + Ok(s) => { + !s.is_empty() && s != "0" + }, + Err(_) => false, + } + } +} + impl QMatMul { pub fn from_arc(qtensor: std::sync::Arc) -> Result { let dequantize = match qtensor.dtype() { @@ -400,6 +427,9 @@ impl QMatMul { let t = if dequantize { let tensor = qtensor.dequantize(&qtensor.device())?; Self::Tensor(tensor) + } else if DEQUANTIZE_ALL_F16.with(|b| *b) { + let tensor = qtensor.dequantize_f16(&qtensor.device())?; + Self::TensorF16(tensor) } else { Self::QTensor(qtensor) }; @@ -486,6 +516,15 @@ impl crate::Module for QMatMul { }; xs.matmul(&w) } + Self::TensorF16(w) => { + let in_dtype = xs.dtype(); + let w = match *xs.dims() { + [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?, + [bsize, _, _] => w.broadcast_left(bsize)?.t()?, + _ => w.t()?, + }; + xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype) + } } } } diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index b2a64ac9cd..8011333cae 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -3,7 +3,7 @@ use candle_core::{ quantized::{self, GgmlDType}, test_device, test_utils::to_vec2_round, - Device, IndexOp, Module, Result, Tensor, + DType, Device, IndexOp, Module, Result, Tensor, }; use quantized::{k_quants, GgmlType}; use rand::prelude::*; @@ -225,6 +225,13 @@ fn quantize_q4_0(device: &Device) -> Result<()> { let src = Tensor::from_slice(&src, (32 * 4,), device)?; let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?; let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); assert_eq!( dst.to_vec1::()?, &[ @@ -251,6 +258,13 @@ fn quantize_q4_1(device: &Device) -> Result<()> { let src = Tensor::from_slice(&src, (32 * 4,), device)?; let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?; let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); assert_eq!( round_vector(&dst.to_vec1::()?), &[ @@ -277,6 +291,13 @@ fn quantize_q5_0(device: &Device) -> Result<()> { let src = Tensor::from_slice(&src, (32 * 4,), device)?; let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?; let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); assert_eq!( round_vector(&dst.to_vec1::()?), &[ @@ -303,6 +324,13 @@ fn quantize_q5_1(device: &Device) -> Result<()> { let src = Tensor::from_slice(&src, (32 * 4,), device)?; let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?; let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); assert_eq!( round_vector(&dst.to_vec1::()?), &[ @@ -387,6 +415,13 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3 let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?; let quant = quantized::QTensor::quantize(&src, dtype)?; let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); let error = calculate_rmse(&src.to_vec1::()?, &dst.to_vec1::()?); if error > max_error { bail!( @@ -404,6 +439,13 @@ fn quantize_q2k(device: &Device) -> Result<()> { let src = get_test_vector2(0.5, 1024, device)?; let quant = quantized::QTensor::quantize(&src, dtype)?; let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); let src = src.to_vec1::()?; let dst = dst.to_vec1::()?; @@ -423,6 +465,13 @@ fn quantize_q2k(device: &Device) -> Result<()> { let src_big = get_test_vector2(128.0, 1024, device)?; let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let dst_big = quant_big.dequantize(device)?; + let dst_big_f16 = quant_big.dequantize_f16(device)?; + let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); let src_big = src_big.to_vec1::()?; let dst_big = dst_big.to_vec1::()?; @@ -437,6 +486,13 @@ fn quantize_q3k(device: &Device) -> Result<()> { let src = get_test_vector2(0.5, 1024, device)?; let quant = quantized::QTensor::quantize(&src, dtype)?; let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); let src = src.to_vec1::()?; let dst = dst.to_vec1::()?; @@ -456,6 +512,13 @@ fn quantize_q3k(device: &Device) -> Result<()> { let src_big = get_test_vector2(128.0, 1024, device)?; let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let dst_big = quant_big.dequantize(device)?; + let dst_big_f16 = quant_big.dequantize_f16(device)?; + let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); let src_big = src_big.to_vec1::()?; let dst_big = dst_big.to_vec1::()?; @@ -470,6 +533,13 @@ fn quantize_q4k(device: &Device) -> Result<()> { let src = get_test_vector2(0.5, 1024, device)?; let quant = quantized::QTensor::quantize(&src, dtype)?; let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); let src = src.to_vec1::()?; let dst = dst.to_vec1::()?; @@ -489,6 +559,13 @@ fn quantize_q4k(device: &Device) -> Result<()> { let src_big = get_test_vector2(128.0, 1024, device)?; let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let dst_big = quant_big.dequantize(device)?; + let dst_big_f16 = quant_big.dequantize_f16(device)?; + let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); let src_big = src_big.to_vec1::()?; let dst_big = dst_big.to_vec1::()?; @@ -503,6 +580,13 @@ fn quantize_q5k(device: &Device) -> Result<()> { let src = get_test_vector2(0.5, 1024, device)?; let quant = quantized::QTensor::quantize(&src, dtype)?; let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); let src = src.to_vec1::()?; let dst = dst.to_vec1::()?; @@ -522,6 +606,13 @@ fn quantize_q5k(device: &Device) -> Result<()> { let src_big = get_test_vector2(128.0, 1024, device)?; let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let dst_big = quant_big.dequantize(device)?; + let dst_big_f16 = quant_big.dequantize_f16(device)?; + let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); let src_big = src_big.to_vec1::()?; let dst_big = dst_big.to_vec1::()?; @@ -536,6 +627,13 @@ fn quantize_q6k(device: &Device) -> Result<()> { let src = get_test_vector2(0.5, 1024, device)?; let quant = quantized::QTensor::quantize(&src, dtype)?; let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); let src = src.to_vec1::()?; let dst = dst.to_vec1::()?; @@ -555,6 +653,13 @@ fn quantize_q6k(device: &Device) -> Result<()> { let src_big = get_test_vector2(128.0, 1024, device)?; let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let dst_big = quant_big.dequantize(device)?; + let dst_big_f16 = quant_big.dequantize_f16(device)?; + let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); let src_big = src_big.to_vec1::()?; let dst_big = dst_big.to_vec1::()?; @@ -569,6 +674,13 @@ fn quantize_q8k(device: &Device) -> Result<()> { let src = get_test_vector2(0.5, 1024, device)?; let quant = quantized::QTensor::quantize(&src, dtype)?; let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); let src = src.to_vec1::()?; let dst = dst.to_vec1::()?; @@ -588,6 +700,13 @@ fn quantize_q8k(device: &Device) -> Result<()> { let src_big = get_test_vector2(128.0, 1024, device)?; let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let dst_big = quant_big.dequantize(device)?; + let dst_big_f16 = quant_big.dequantize_f16(device)?; + let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); let src_big = src_big.to_vec1::()?; let dst_big = dst_big.to_vec1::()?; diff --git a/candle-kernels/src/quantized.cu b/candle-kernels/src/quantized.cu index c5bc45630f..05f878f3d6 100644 --- a/candle-kernels/src/quantized.cu +++ b/candle-kernels/src/quantized.cu @@ -765,20 +765,21 @@ static __device__ void dequantize_block(const void * __restrict__ vx, dst_t * __ y[iybs + iqs + y_offset] = v.y; } -extern "C" __global__ void dequantize_block_q4_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) { +template +static __device__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; // assume 32 threads const int tid = threadIdx.x; const int il = tid/8; const int ir = tid%8; - const int ib = 8*i + ir; + const int64_t ib = 8*i + ir; if (ib >= nb32) { return; } - float * y = yy + 256*i + 32*ir + 4*il; + dst_t * y = yy + 256*i + 32*ir + 4*il; const block_q4_0 * x = (const block_q4_0 *)vx + ib; const float d = __half2float(x->d); @@ -792,20 +793,21 @@ extern "C" __global__ void dequantize_block_q4_0(const void * __restrict__ vx, f } } -extern "C" __global__ void dequantize_block_q4_1(const void * __restrict__ vx, float * __restrict__ yy, int nb32) { +template +static __device__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; // assume 32 threads const int tid = threadIdx.x; const int il = tid/8; const int ir = tid%8; - const int ib = 8*i + ir; + const int64_t ib = 8*i + ir; if (ib >= nb32) { return; } - float * y = yy + 256*i + 32*ir + 4*il; + dst_t * y = yy + 256*i + 32*ir + 4*il; const block_q4_1 * x = (const block_q4_1 *)vx + ib; const float2 d = __half22float2(x->dm); @@ -820,7 +822,8 @@ extern "C" __global__ void dequantize_block_q4_1(const void * __restrict__ vx, f //================================== k-quants -extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) { +template +static __device__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const int i = blockIdx.x; const block_q2_K * x = (const block_q2_K *) vx; @@ -832,7 +835,7 @@ extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, f const int is = 8*n + l/16; const uint8_t q = x[i].qs[32*n + l]; - float * y = yy + i*QK_K + 128*n; + dst_t * y = yy + i*QK_K + 128*n; float dall = __low2half(x[i].dm); float dmin = __high2half(x[i].dm); @@ -844,7 +847,7 @@ extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, f const int is = tid/16; // 0 or 1 const int il = tid%16; // 0...15 const uint8_t q = x[i].qs[il] >> (2*is); - float * y = yy + i*QK_K + 16*is + il; + dst_t * y = yy + i*QK_K + 16*is + il; float dall = __low2half(x[i].dm); float dmin = __high2half(x[i].dm); y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); @@ -853,7 +856,8 @@ extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, f } -extern "C" __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) { +template +static __device__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const int i = blockIdx.x; const block_q3_K * x = (const block_q3_K *) vx; @@ -877,7 +881,7 @@ extern "C" __global__ void dequantize_block_q3_K(const void * __restrict__ vx, f float d_all = x[i].d; float dl = d_all * (us - 32); - float * y = yy + i*QK_K + 128*n + 32*j; + dst_t * y = yy + i*QK_K + 128*n + 32*j; const uint8_t * q = x[i].qs + 32*n; const uint8_t * hm = x[i].hmask; @@ -889,7 +893,7 @@ extern "C" __global__ void dequantize_block_q3_K(const void * __restrict__ vx, f const int im = il/8; // 0...1 const int in = il%8; // 0...7 - float * y = yy + i*QK_K + 16*is + il; + dst_t * y = yy + i*QK_K + 16*is + il; const uint8_t q = x[i].qs[il] >> (2*is); const uint8_t h = x[i].hmask[in] >> (2*is + im); @@ -917,7 +921,8 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t } #endif -extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) { +template +static __device__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const block_q4_K * x = (const block_q4_K *) vx; const int i = blockIdx.x; @@ -930,7 +935,7 @@ extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, f const int is = 2*il; const int n = 4; - float * y = yy + i*QK_K + 64*il + n*ir; + dst_t * y = yy + i*QK_K + 64*il + n*ir; const float dall = __low2half(x[i].dm); const float dmin = __high2half(x[i].dm); @@ -949,7 +954,7 @@ extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, f #else const int tid = threadIdx.x; const uint8_t * q = x[i].qs; - float * y = yy + i*QK_K; + dst_t * y = yy + i*QK_K; const float d = (float)x[i].dm[0]; const float m = (float)x[i].dm[1]; y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4); @@ -957,7 +962,8 @@ extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, f #endif } -extern "C" __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) { +template +static __device__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const block_q5_K * x = (const block_q5_K *) vx; const int i = blockIdx.x; @@ -969,7 +975,7 @@ extern "C" __global__ void dequantize_block_q5_K(const void * __restrict__ vx, f const int ir = tid%16; // ir is in 0...15 const int is = 2*il; // is is in 0...6 - float * y = yy + i*QK_K + 64*il + 2*ir; + dst_t * y = yy + i*QK_K + 64*il + 2*ir; const float dall = __low2half(x[i].dm); const float dmin = __high2half(x[i].dm); @@ -997,25 +1003,26 @@ extern "C" __global__ void dequantize_block_q5_K(const void * __restrict__ vx, f const int is = tid/16; // 0 or 1 const uint8_t h = x[i].qh[in] >> im; const float d = x[i].d; - float * y = yy + i*QK_K + tid; + dst_t * y = yy + i*QK_K + tid; y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16)); y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16)); #endif } -extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) { +template +static __device__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const block_q6_K * x = (const block_q6_K *) vx; - const int i = blockIdx.x; + const int64_t i = blockIdx.x; #if QK_K == 256 // assume 64 threads - this is very slightly better than the one below - const int tid = threadIdx.x; - const int ip = tid/32; // ip is 0 or 1 - const int il = tid - 32*ip; // 0...32 - const int is = 8*ip + il/16; + const int64_t tid = threadIdx.x; + const int64_t ip = tid/32; // ip is 0 or 1 + const int64_t il = tid - 32*ip; // 0...32 + const int64_t is = 8*ip + il/16; - float * y = yy + i*QK_K + 128*ip + il; + dst_t * y = yy + i*QK_K + 128*ip + il; const float d = x[i].d; @@ -1030,11 +1037,11 @@ extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, f #else // assume 32 threads - const int tid = threadIdx.x; - const int ip = tid/16; // 0 or 1 - const int il = tid - 16*ip; // 0...15 + const int64_t tid = threadIdx.x; + const int64_t ip = tid/16; // 0 or 1 + const int64_t il = tid - 16*ip; // 0...15 - float * y = yy + i*QK_K + 16*ip + il; + dst_t * y = yy + i*QK_K + 16*ip + il; const float d = x[i].d; @@ -1047,7 +1054,8 @@ extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, f #endif } -extern "C" __global__ void dequantize_block_q8_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) { +template +static __device__ void dequantize_block_q8_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) { const int i = blockIdx.x; // assume 32 threads @@ -1059,7 +1067,7 @@ extern "C" __global__ void dequantize_block_q8_0(const void * __restrict__ vx, f return; } - float * y = yy + 256*i + 32*ir + 8*il; + dst_t * y = yy + 256*i + 32*ir + 8*il; const block_q8_0 * x = (const block_q8_0 *)vx + ib; const float d = __half2float(x->d); @@ -1071,7 +1079,8 @@ extern "C" __global__ void dequantize_block_q8_0(const void * __restrict__ vx, f } } -extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, float * __restrict__ yy) { +template +static __device__ void dequantize_block_q8_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const block_q8_K * x = (const block_q8_K *) vx; const int i = blockIdx.x; @@ -1083,7 +1092,7 @@ extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, f const int ir = tid%8; const int n = 8; - float * y = yy + i*QK_K + 64*il + n*ir; + dst_t * y = yy + i*QK_K + 64*il + n*ir; const int8_t * q = x[i].qs + 64*il + n*ir; @@ -1098,14 +1107,43 @@ extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, f #endif } -extern "C" __global__ void dequantize_block_q5_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) { +template +static __device__ void dequantize_block_q5_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) { return dequantize_block(vx, yy, nb32); } -extern "C" __global__ void dequantize_block_q5_1(const void * __restrict__ vx, float * __restrict__ yy, int nb32) { +template +static __device__ void dequantize_block_q5_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) { return dequantize_block(vx, yy, nb32); } +#define DEQUANTIZE_K(QNAME) \ +extern "C" __global__ void dequantize_block_##QNAME##_f32(const void * __restrict__ vx, float * __restrict__ y) { \ + dequantize_block_##QNAME(vx, y); \ +} \ +extern "C" __global__ void dequantize_block_##QNAME##_f16(const void * __restrict__ vx, half * __restrict__ y) { \ + dequantize_block_##QNAME(vx, y); \ +} \ + +#define DEQUANTIZE(QNAME) \ +extern "C" __global__ void dequantize_block_##QNAME##_f32(const void * __restrict__ vx, float * __restrict__ y, const int k) { \ + dequantize_block_##QNAME(vx, y, k); \ +} \ +extern "C" __global__ void dequantize_block_##QNAME##_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) { \ + dequantize_block_##QNAME(vx, y, k); \ +} \ + +DEQUANTIZE_K(q2_K) +DEQUANTIZE_K(q3_K) +DEQUANTIZE_K(q4_K) +DEQUANTIZE_K(q5_K) +DEQUANTIZE_K(q6_K) +DEQUANTIZE_K(q8_K) +DEQUANTIZE(q4_0) +DEQUANTIZE(q4_1) +DEQUANTIZE(q5_0) +DEQUANTIZE(q5_1) +DEQUANTIZE(q8_0) template static __device__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {