From 253725c0461d42e91a403e6b6f445775aba0d14c Mon Sep 17 00:00:00 2001
From: baishihao
Date: Tue, 5 Nov 2024 11:22:52 +0800
Subject: [PATCH] add vllm fp8 w8a8 (per-channel/per-token)

---
 .../common/layers/quantization/__init__.py   |  3 ++-
 .../common/layers/quantization/vllm_quant.py | 20 +++++++++++++++++++
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/lightllm/common/layers/quantization/__init__.py b/lightllm/common/layers/quantization/__init__.py
index 44a91ee2..92e04d63 100644
--- a/lightllm/common/layers/quantization/__init__.py
+++ b/lightllm/common/layers/quantization/__init__.py
@@ -6,7 +6,7 @@
     AOFP8W8A16QuantizationMethod,
     AOFP6W6A16QuantizationMethod,
 )
-from .vllm_quant import vLLMw8a8QuantizationMethod
+from .vllm_quant import vLLMw8a8QuantizationMethod, vLLMFP8w8a8QuantizationMethod
 
 QUANTIZATION_METHODS = {
     "ppl_w4a16": PPLW4A16QuantizationMethod,
@@ -17,6 +17,7 @@
     "ao-fp8w8a16": AOFP8W8A16QuantizationMethod,
     "ao-fp6w6a16": AOFP6W6A16QuantizationMethod,
     "vllm-w8a8": vLLMw8a8QuantizationMethod,
+    "vllm-fp8w8a8": vLLMFP8w8a8QuantizationMethod,
 }
 
diff --git a/lightllm/common/layers/quantization/vllm_quant.py b/lightllm/common/layers/quantization/vllm_quant.py
index f50ec5c7..2a9f6cdc 100644
--- a/lightllm/common/layers/quantization/vllm_quant.py
+++ b/lightllm/common/layers/quantization/vllm_quant.py
@@ -49,3 +49,23 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None):
         )
         torch.ops._C.cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias)
         return out
+
+
+class vLLMFP8w8a8QuantizationMethod(vLLMBaseQuantizationMethod):
+    def __init__(self):
+        super().__init__()
+
+    def quantize(self, weight: torch.Tensor):
+        qweight, weight_scale = ops.scaled_fp8_quant(weight.cuda(), scale=None, use_per_token_if_dynamic=True)
+        return qweight.transpose(0, 1), weight_scale
+
+    def apply(self, input_tensor, weights, bias=None, out=None, workspace=None):
+        x_q, x_scale = ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
+        m = input_tensor.shape[0]
+        n = weights[0].shape[1]
+        if out is None:
+            out = g_cache_manager.alloc_tensor(
+                (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False
+            )
+        torch.ops._C.cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias)
+        return out
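
Usage sketch (not part of the patch): a minimal exercise of the new
"vllm-fp8w8a8" method, assuming a vLLM build that provides
ops.scaled_fp8_quant and torch.ops._C.cutlass_scaled_mm and an FP8-capable
GPU (compute capability 8.9+). The shapes and the explicit `out` buffer are
illustrative; passing `out` avoids depending on lightllm's internal
g_cache_manager in a standalone script.

    import torch
    from lightllm.common.layers.quantization import QUANTIZATION_METHODS

    method = QUANTIZATION_METHODS["vllm-fp8w8a8"]()

    # Per-channel weight quantization: an (N, K) half-precision weight is
    # quantized to fp8, returned transposed as (K, N) for the CUTLASS GEMM,
    # along with one scale per output channel.
    weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
    qweight, weight_scale = method.quantize(weight)

    # apply() quantizes activations per token on the fly, then runs the
    # fp8 GEMM and rescales with x_scale and weight_scale.
    x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
    out = torch.empty(8, 4096, dtype=torch.float16, device="cuda")
    y = method.apply(x, (qweight, weight_scale), out=out)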