Skip to content

Commit

Permalink
[Model] Add NVFP4 checkpoint support
Browse files Browse the repository at this point in the history
Signed-off-by: Pavani Majety <[email protected]>
  • Loading branch information
pavanimajety committed Jan 28, 2025
1 parent 0f657bd commit c20e87d
Show file tree
Hide file tree
Showing 5 changed files with 336 additions and 13 deletions.
80 changes: 80 additions & 0 deletions tests/models/decoder_only/language/test_nvfp4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# flake8: noqa
"""Tests Model Optimizer fp8 models against ground truth generation
Note: these tests will only pass on H100
"""
import os
from typing import List

import pytest
from transformers import AutoTokenizer

from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams

os.environ["TOKENIZERS_PARALLELISM"] = "true"

MAX_MODEL_LEN = 1024

MODELS = ["nvidia/Llama-3.3-8B-Instruct-FP4"]

EXPECTED_STRS_MAP = {
"nvidia/Llama-3.1-8B-Instruct-FP8": [
"You're referring to VLLM, a high-performance Large Language Model (LLM) inference and",
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'The comparison between artificial intelligence (AI) and human intelligence in terms of processing information is a complex and',
'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
'**The Spark of Imagination**\n\nZeta-5, a sleek and efficient robot, whir',
'The COVID-19 pandemic has had a profound impact on global economic structures and business models, leading to',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** 「早起きは早く獲物をとる'
]
}


# This test compares against golden strings for exact match since
# there is no baseline implementation to compare against
# and is unstable w.r.t specifics of the fp8 implementation or
# the hardware being run on.
# Disabled to prevent it from breaking the build
@pytest.mark.skip(
reason=
"Prevent unstable test based on golden strings from breaking the build.")
@pytest.mark.quant_model
@pytest.mark.skipif(not is_quant_method_supported("fp4"),
reason="fp4 is not supported on this GPU type.")
@pytest.mark.parametrize("model_name", MODELS)
def test_models(example_prompts, model_name) -> None:
model = LLM(
model=model_name,
max_model_len=MAX_MODEL_LEN,
trust_remote_code=True,
enforce_eager=True,
quantization="nvfp4",
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
formatted_prompts = [
tokenizer.apply_chat_template([{
"role": "user",
"content": prompt
}],
tokenize=False,
add_generation_prompt=True)
for prompt in example_prompts
]
params = SamplingParams(max_tokens=20, temperature=0)
generations: List[str] = []
# Note: these need to be run 1 at a time due to numerical precision,
# since the expected strs were generated this way.
for prompt in formatted_prompts:
outputs = model.generate(prompt, params)
generations.append(outputs[0].outputs[0].text)
del model

print(model_name, generations)
expected_strs = EXPECTED_STRS_MAP[model_name]
for i in range(len(example_prompts)):
generated_str = generations[i]
expected_str = expected_strs[i]
assert expected_str == generated_str, (
f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}")
2 changes: 1 addition & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def _verify_quantization(self) -> None:
optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed_tensors",
"compressed-tensors", "experts_int8", "quark"
"compressed-tensors", "experts_int8", "quark", "nvfp4"
]
if self.quantization is not None:
self.quantization = self.quantization.lower()
Expand Down
23 changes: 17 additions & 6 deletions vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,23 @@
logger = init_logger(__name__)

WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod",
"HQQMarlinMethod", "QuarkLinearMethod"
"CompressedTensorsLinearMethod",
"AWQMarlinLinearMethod",
"AWQLinearMethod",
"GPTQMarlinLinearMethod",
"Fp8LinearMethod",
"MarlinLinearMethod",
"QQQLinearMethod",
"GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod",
"GPTQLinearMethod",
"FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod",
"IPEXAWQLinearMethod",
"IPEXGPTQLinearMethod",
"HQQMarlinMethod",
"QuarkLinearMethod",
"ModelOptNvFp4LinearMethod",
]


Expand Down
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"fp8",
"fbgemm_fp8",
"modelopt",
"nvfp4",
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin",
Expand Down Expand Up @@ -93,7 +94,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from .hqq_marlin import HQQMarlinConfig
from .ipex_quant import IPEXConfig
from .marlin import MarlinConfig
from .modelopt import ModelOptFp8Config
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
from .neuron_quant import NeuronQuantConfig
from .qqq import QQQConfig
from .tpu_int8 import Int8TpuConfig
Expand All @@ -106,6 +107,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config,
"modelopt": ModelOptFp8Config,
"nvfp4": ModelOptNvFp4Config,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin": MarlinConfig,
Expand Down
Loading

0 comments on commit c20e87d

Please sign in to comment.