From 78597190dce6c789d0bf28996747bc2c82ec7641 Mon Sep 17 00:00:00 2001
From: Laurent
Date: Sun, 28 Apr 2024 20:25:40 +0200
Subject: [PATCH] Add a forward_via_f16 method to the qmatmul op.

---
 candle-core/src/quantized/mod.rs | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index e87072bbdb..d852d50410 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -439,6 +439,25 @@ impl QMatMul {
     pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
         Self::from_arc(std::sync::Arc::new(qtensor))
     }
+
+    pub fn dequantize_f16(&self) -> Result<Tensor> {
+        match self {
+            Self::QTensor(t) => t.dequantize_f16(&t.device()),
+            Self::Tensor(t) => t.to_dtype(DType::F16),
+            Self::TensorF16(t) => Ok(t.clone()),
+        }
+    }
+
+    pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
+        let w = self.dequantize_f16()?;
+        let in_dtype = xs.dtype();
+        let w = match *xs.dims() {
+            [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
+            [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
+            _ => w.t()?,
+        };
+        xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
+    }
 }
 
 impl crate::CustomOp1 for QTensor {
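
A minimal usage sketch of the new method, assuming the patch above is applied to candle-core; the weight/input shapes, the Q4_0 choice, and the variable names are illustrative and not part of the patch:

// Usage sketch (not part of the patch): exercises forward_via_f16 against the
// regular forward path. Shapes and the Q4_0 dtype are arbitrary examples.
use candle_core::quantized::{GgmlDType, QMatMul, QTensor};
use candle_core::{Device, Module, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Hypothetical weight of shape (out_features, in_features); the last dim
    // should be a multiple of the quantization block size.
    let w = Tensor::randn(0f32, 1f32, (8, 256), &dev)?;
    let qm = QMatMul::from_qtensor(QTensor::quantize(&w, GgmlDType::Q4_0)?)?;

    let xs = Tensor::randn(0f32, 1f32, (3, 256), &dev)?;
    // Existing path.
    let ys = qm.forward(&xs)?;
    // New path from this patch: dequantize the weight to f16, run the matmul
    // in f16, then cast the result back to the input dtype.
    let ys_f16 = qm.forward_via_f16(&xs)?;
    assert_eq!(ys.dims(), ys_f16.dims());
    Ok(())
}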