add vllm fp8 w8a8 (per-channel/per-token)
baishihao committed Nov 5, 2024
1 parent cd92bf6 commit 253725c
Showing 2 changed files with 22 additions and 1 deletion.
3 changes: 2 additions & 1 deletion lightllm/common/layers/quantization/__init__.py
@@ -6,7 +6,7 @@
     AOFP8W8A16QuantizationMethod,
     AOFP6W6A16QuantizationMethod,
 )
-from .vllm_quant import vLLMw8a8QuantizationMethod
+from .vllm_quant import vLLMw8a8QuantizationMethod, vLLMFP8w8a8QuantizationMethod
 
 QUANTIZATION_METHODS = {
     "ppl_w4a16": PPLW4A16QuantizationMethod,
@@ -17,6 +17,7 @@
     "ao-fp8w8a16": AOFP8W8A16QuantizationMethod,
     "ao-fp6w6a16": AOFP6W6A16QuantizationMethod,
     "vllm-w8a8": vLLMw8a8QuantizationMethod,
+    "vllm-fp8w8a8": vLLMFP8w8a8QuantizationMethod,
 }
 
 
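For context, a minimal lookup sketch of how this registry entry gets used; the get_quantization_method helper below is illustrative, not lightllm's API. The registry maps a config string such as "vllm-fp8w8a8" to the class implementing quantize()/apply().

from lightllm.common.layers.quantization import QUANTIZATION_METHODS

def get_quantization_method(name):
    # Hypothetical helper: resolve a config string to a method instance.
    if name not in QUANTIZATION_METHODS:
        raise ValueError(f"unknown quantization method: {name}")
    return QUANTIZATION_METHODS[name]()

method = get_quantization_method("vllm-fp8w8a8")  # -> vLLMFP8w8a8QuantizationMethod instance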
20 changes: 20 additions & 0 deletions lightllm/common/layers/quantization/vllm_quant.py
@@ -49,3 +49,23 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None):
             )
         torch.ops._C.cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias)
         return out
+
+
+class vLLMFP8w8a8QuantizationMethod(vLLMBaseQuantizationMethod):
+    def __init__(self):
+        super().__init__()
+
+    def quantize(self, weight: torch.Tensor):
+        qweight, weight_scale = ops.scaled_fp8_quant(weight.cuda(), scale=None, use_per_token_if_dynamic=True)
+        return qweight.transpose(0, 1), weight_scale
+
+    def apply(self, input_tensor, weights, bias=None, out=None, workspace=None):
+        x_q, x_scale = ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
+        m = input_tensor.shape[0]
+        n = weights[0].shape[1]
+        if out is None:
+            out = g_cache_manager.alloc_tensor(
+                (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False
+            )
+        torch.ops._C.cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias)
+        return out
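For reference, a standalone sketch of the dynamic FP8 W8A8 flow the new class implements: weights are quantized once with per-channel scales, activations are quantized per token at call time, and the FP8 GEMM folds both scale vectors into the output. It swaps a plain torch.empty buffer in for lightllm's g_cache_manager, assumes vllm with its compiled _C ops is installed, and needs an FP8-capable GPU (e.g. Hopper); the shapes are illustrative only.

import torch
from vllm import _custom_ops as ops

m, k, n = 16, 4096, 4096
weight = torch.randn(n, k, dtype=torch.float16, device="cuda")  # (out_features, in_features)
x = torch.randn(m, k, dtype=torch.float16, device="cuda")       # one activation row per token

# quantize(): dynamic per-channel weight scales of shape (n, 1); the transpose
# yields a column-major (k, n) view, the layout cutlass_scaled_mm expects.
qweight, w_scale = ops.scaled_fp8_quant(weight, scale=None, use_per_token_if_dynamic=True)
qweight = qweight.transpose(0, 1)

# apply(): dynamic per-token activation scales of shape (m, 1), then the scaled
# FP8 GEMM writes the rescaled result into the fp16 output buffer.
x_q, x_scale = ops.scaled_fp8_quant(x, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
out = torch.empty((m, n), dtype=x.dtype, device=x.device)
torch.ops._C.cutlass_scaled_mm(out, x_q, qweight, x_scale, w_scale, None)

Per-token activation scaling costs slightly more than a single per-tensor scale but tracks outlier tokens better, which is presumably why the commit enables use_per_token_if_dynamic on both the weight and activation paths.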
