[ROCm] [Feature] [Doc] [Dockerfile] Support Per-Token-Activation Per-Channel-Weight FP8 Quantization Inferencing #12501

Open · wants to merge 4 commits into main

Conversation

@tjtanaa (Contributor) commented on Jan 28, 2025

Support Per-Token-Activation Per-Channel-Weight FP8 Quantization Inferencing

Note: This feature requires ROCm 6.3 or later and MI300-series (or newer) GPUs.

Description

This PR includes the following enhancements:

  1. Adds support for Per-Token-Activation Per-Channel-Weight (PTPC-FP8) FP8 quantization inference.
    The model is quantized on-the-fly from BFloat16 to FP8; model weights stored in Float16 must first be cast to BFloat16. (A minimal sketch of the two scaling granularities follows this list.)

  2. Uses PyTorch's rowwise scaled GEMM feature in torch._scaled_mm, introduced in [ROCm] hipblaslt rowwise f8 gemm pytorch/pytorch#144432, which speeds up the previous naive implementation by at least 2x. For more details, see the Performance section.

  • To support this feature, the PyTorch repo commit in Dockerfile.rocm_base has been updated to 3a585126.
  • Dockerfile.rocm is left untouched because its base image is pulled from the AMD Docker Hub registry, and that base image already ships PyTorch at commit 3a585126.

  3. Small enhancement: the documentation has been updated to ROCm 6.3, and the commits in the installation steps have been updated to match those in Dockerfile.rocm_base.
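
As a rough illustration of what per-token-activation, per-channel-weight scaling means (this is not code from the PR; the FP8 dtype choice and helper names are assumptions for the sketch), each activation row gets its own scale and each weight output channel gets its own scale:

```python
import torch

# Assumed FP8 format for MI300/ROCm; NVIDIA GPUs would use torch.float8_e4m3fn.
FP8_DTYPE = torch.float8_e4m3fnuz
FP8_MAX = torch.finfo(FP8_DTYPE).max


def quantize_per_token(x: torch.Tensor):
    """One scale per token, i.e. per row of the [num_tokens, hidden] activations."""
    scale = x.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12) / FP8_MAX
    return (x / scale).to(FP8_DTYPE), scale


def quantize_per_channel(w: torch.Tensor):
    """One scale per output channel, i.e. per row of the [out_features, in_features] weight."""
    scale = w.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12) / FP8_MAX
    return (w / scale).to(FP8_DTYPE), scale
```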

Performance

Perplexity Test

Model: Llama-3.1-8B-Instruct
Dataset: WikiText
GPU: MI300X

| Model | Quantization | KV Cache Dtype | Task | Metric | Score |
|---|---|---|---|---|---|
| Llama-3.1-8B-Instruct | auto (bf16) | auto (bf16) | wikitext | word_perplexity | 9.4281 |
| Llama-3.1-8B-Instruct | fp8 | fp8_e4m3 | wikitext | word_perplexity | 9.5124 |
| Llama-3.1-8B-Instruct | ptpc_fp8 | fp8_e4m3 | wikitext | word_perplexity | 9.5093 |
| Llama-3.1-8B-Instruct | ptpc_fp8 (naive) | fp8_e4m3 | wikitext | word_perplexity | 9.5095 |

Speed Test (Old naive implementation vs torch._scaled_mm rowwise scaled GEMM feature)

Model: Llama-3.1-70B-Instruct
Dataset: ShareGPT
GPU: 1xMI300X

| Quantization | KV Cache Dtype | Req/s | Total tokens/s | Output tokens/s |
|---|---|---|---|---|
| ptpc_fp8 (naive) | fp8_e4m3 | 2.43 | 1003.46 | 481.28 |
| ptpc_fp8 (torch._scaled_mm rowwise scaled GEMM) | fp8_e4m3 | 6.36 | 2631.04 | 1261.91 |

PTPC_FP8 (naive)

The naive path runs the GEMM with identity scales and then applies the per-token and per-channel scales afterwards as a separate elementwise multiply:

```python
# Making sure the dummy tensor is on the same device as the weight
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY.device != weight.device:
    TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)

# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm(qinput,
                          weight,
                          scale_a=TORCH_DEVICE_IDENTITY,
                          scale_b=TORCH_DEVICE_IDENTITY,
                          out_dtype=torch.float32)
# A fix for discrepancy in scaled_mm which returns a tuple
# for torch < 2.5 and a single value for torch >= 2.5
if type(output) is tuple and len(output) == 2:
    output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])

# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * weight_scale.t()
```
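
For comparison, here is a minimal sketch of the fused rowwise-scaled path that replaces the code above (illustrative only, not the exact code added by this PR; it assumes a PyTorch build containing the hipBLASLt rowwise FP8 GEMM from pytorch/pytorch#144432, where torch._scaled_mm returns a single tensor). The per-token and per-channel scales are passed directly to the GEMM, so the fp32 output and the separate elementwise rescale are no longer needed:

```python
# qinput: FP8 activations [M, K]; weight: FP8 weights [K, N] (column-major);
# x_scale: per-token scales [M, 1]; weight_scale: per-channel scales [N, 1].
output = torch._scaled_mm(qinput,
                          weight,
                          scale_a=x_scale,
                          scale_b=weight_scale.t(),  # broadcast over columns, [1, N]
                          out_dtype=torch.bfloat16)
# Unpad (undo num_token_padding), as in the naive path above.
output = torch.narrow(output, 0, 0, input_2d.shape[0])
```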


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small but essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

Signed-off-by: tjtanaa <[email protected]>
@mgoin (Member) left a comment

Thanks for your contribution. Do you think it would be possible to implement this inside of fp8.py? It seems we could just change the default --quantization fp8 on an unquantized model to use per-token and per-channel. Given the cutlass and pytorch support we have now, I don't think there is a great reason to rely on per-tensor by default anymore.

@tjtanaa (Contributor, Author) commented on Jan 29, 2025

> Thanks for your contribution. Do you think it would be possible to implement this inside of fp8.py? It seems we could just change the default --quantization fp8 on an unquantized model to use per-token and per-channel. Given the cutlass and pytorch support we have now, I don't think there is a great reason to rely on per-tensor by default anymore.

@mgoin We think it should be possible to implement PTPC-FP8 inside fp8.py and force --quantization fp8 on an unquantized model to use per-token and per-channel quantization on ROCm. Changing the default behavior of --quantization fp8 on NVIDIA GPUs would require an additional PR.

@hongxiayang @mgoin Maybe we could get input from AMD on whether there is any preference or demand for keeping per-tensor quantization for backward compatibility.

While we are on this topic: since making vLLM production-ready is a goal, I wonder whether we need to maintain a certain level of backward compatibility so that the behavior of existing features changes as little as possible.
Moreover, when we added this new quantization feature and wanted to document its behavior, usage, and expectations, we could not find a page for it. Is there an RFC for documenting quantization approaches?

Labels: ci/build, documentation, rocm
4 participants