diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index e87072bbd..d852d5041 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -439,6 +439,25 @@ impl QMatMul { pub fn from_qtensor(qtensor: QTensor) -> Result { Self::from_arc(std::sync::Arc::new(qtensor)) } + + pub fn dequantize_f16(&self) -> Result { + match self { + Self::QTensor(t) => t.dequantize_f16(&t.device()), + Self::Tensor(t) => t.to_dtype(DType::F16), + Self::TensorF16(t) => Ok(t.clone()), + } + } + + pub fn forward_via_f16(&self, xs: &Tensor) -> Result { + let w = self.dequantize_f16()?; + let in_dtype = xs.dtype(); + let w = match *xs.dims() { + [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?, + [bsize, _, _] => w.broadcast_left(bsize)?.t()?, + _ => w.t()?, + }; + xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype) + } } impl crate::CustomOp1 for QTensor {