vLLM Integration of FusedMoE #17

Open
robertgshaw2-redhat opened this issue Apr 8, 2024 · 11 comments
@robertgshaw2-redhat

robertgshaw2-redhat commented Apr 8, 2024

FYI: there is an ongoing PR to integrate this great work into vLLM (vllm-project/vllm#3905).

I ran into a couple of correctness issues with a few shapes running through our CI.

To repro:

git clone https://github.com/neuralmagic/nm-vllm.git
cd nm-vllm
git checkout fused-moe
pip install -e .
pip install -r requirements-dev.txt
pytest -v tests/kernels/test_moe.py::test_fused_moe
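For context on what `test_fused_moe` exercises: tests like this typically compare the fused Triton kernel's output against a naive reference. Below is a minimal pure-Python sketch of such a reference, assuming `m` tokens of hidden size `k`, `e` experts, each token routed to its `topk` highest-scoring experts, and each expert simplified to a single `n x k` linear map (real Mixtral-style experts are gated MLPs, and the top-k weights here are not renormalized). The helper names are hypothetical and are not vLLM's API.

```python
import math


def softmax(xs):
    # Numerically stable softmax over a list of floats.
    mx = max(xs)
    exps = [math.exp(x - mx) for x in xs]
    total = sum(exps)
    return [v / total for v in exps]


def matvec(w, x):
    # w is a list of rows (out_dim x in_dim); x is a vector of length in_dim.
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def reference_moe(x, gate_w, experts, topk):
    """Naive MoE forward pass (hypothetical helper, not vLLM's reference).

    x:       m tokens, each a list of k floats
    gate_w:  e router rows, each a list of k floats
    experts: e expert matrices, each n rows of k floats
    topk:    number of experts each token is routed to
    """
    out = []
    for tok in x:
        probs = softmax(matvec(gate_w, tok))
        # Indices of the topk most probable experts for this token.
        chosen = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:topk]
        acc = [0.0] * len(experts[0])
        for i in chosen:
            y = matvec(experts[i], tok)
            # Weight each selected expert's output by its router probability.
            acc = [a + probs[i] * yi for a, yi in zip(acc, y)]
        out.append(acc)
    return out
```

A correctness test then checks that the fused kernel's output matches a reference like this within a floating-point tolerance.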
robertgshaw2-redhat changed the title from "vLLM Integration" to "vLLM Integration of FusedMoE" on Apr 8, 2024
@robertgshaw2-redhat
Author

Note: I am able to get all tests passing in vLLM if I only run the following shapes:

@pytest.mark.parametrize("m", [2, 4, 8, 16, 32, 64, 128, 512, 1024, 2048])
@pytest.mark.parametrize("n", [14336//2])
@pytest.mark.parametrize("k", [4096])
@pytest.mark.parametrize("e", [8])
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize("dtype", [torch.float16])

This matches the test suite of this repo.

The test suite we have in vLLM runs the following:

@pytest.mark.parametrize("m", [512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
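To make the difference between the two suites concrete, here is a small sketch that expands both parameter grids the way stacked `pytest.mark.parametrize` decorators do (a Cartesian product) and lists the `m`, `n`, `k` values that are not multiples of 16. The multiple of 16 is an illustrative assumption, not a block size taken from the kernel, and this comparison does not by itself identify which shapes trigger the failures.

```python
import itertools

# Parameter grids copied from the two parametrize blocks above
# (dtypes written as strings so the sketch does not need torch).
repo_grid = {
    "m": [2, 4, 8, 16, 32, 64, 128, 512, 1024, 2048],
    "n": [14336 // 2],
    "k": [4096],
    "e": [8],
    "topk": [2],
    "dtype": ["float16"],
}
vllm_grid = {
    "m": [512, 222, 33, 1],
    "n": [2048, 256, 1024],
    "k": [128, 511, 1024],
    "e": [8, 64],
    "topk": [2, 6],
    "dtype": ["float16", "bfloat16"],
}


def num_cases(grid):
    # Stacked parametrize decorators run the Cartesian product of all values.
    return sum(1 for _ in itertools.product(*grid.values()))


def unaligned(grid, multiple=16):
    # GEMM dimensions that are not a multiple of `multiple` (an assumed,
    # illustrative tile size), i.e. shapes that need masked edge tiles.
    return {dim: [v for v in grid[dim] if v % multiple != 0] for dim in ("m", "n", "k")}


print(num_cases(repo_grid), num_cases(vllm_grid))  # 10 288
print(unaligned(repo_grid))  # {'m': [2, 4, 8], 'n': [], 'k': []}
print(unaligned(vllm_grid))  # {'m': [222, 33, 1], 'n': [], 'k': [511]}
```

The vLLM grid also varies `e` (up to 64 experts) and `topk` (up to 6), which the repo's grid holds fixed at 8 and 2.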

@AdnanHoque
Contributor

Hi @robertgshaw2-neuralmagic, thanks for running these! Can you try the V1 kernel we have in the meantime, while we debug this?

The V1 kernel features the SplitK Optimization: https://github.com/pytorch-labs/applied-ai/blob/main/kernels/triton/inference/col_major_moe_gemm/v1_moe_fused.py
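In general terms, SplitK splits the reduction dimension K of a GEMM into `split_k` slices whose partial products are computed independently and then summed. When M is small there are few output tiles, so splitting K gives the GPU more independent work to spread across its SMs. A minimal pure-Python sketch of the decomposition, assuming matrices stored as lists of rows; it only checks the arithmetic, not the Triton kernel's launch or memory behavior.

```python
def matmul(a, b):
    # Plain reference GEMM: a is M x K, b is K x N (lists of rows).
    k, n = len(b), len(b[0])
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(n)] for row in a]


def splitk_matmul(a, b, split_k):
    """Compute a @ b by splitting K into `split_k` contiguous slices.

    Each slice contributes a full M x N partial product. On a GPU these
    partials would come from different programs and be combined with
    atomic adds; here they are simply accumulated in order.
    """
    m, k, n = len(a), len(b), len(b[0])
    chunk = -(-k // split_k)  # ceil division; trailing slices may be short or empty
    out = [[0.0] * n for _ in range(m)]
    for s in range(split_k):
        lo, hi = s * chunk, min((s + 1) * chunk, k)
        for i in range(m):
            for j in range(n):
                # Partial dot product over this K slice only.
                out[i][j] += sum(a[i][p] * b[p][j] for p in range(lo, hi))
    return out
```

Because each slice only touches its own range of K, the result is identical to the unsplit product (up to floating-point summation order).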

@robertgshaw2-redhat
Author

Thanks for taking a look at fixing them. Let me know where you end up.

To be clear, with the V1 kernels I should see a ~10-20% speedup, right?

@AdnanHoque
Contributor

Indeed! 18% on A100 and 20% on H100 is what we measured, with the parameters already present in V1.

You might be able to squeeze out more if a "split_k" parameter is added here: https://github.com/vllm-project/vllm/blob/f46864d68dfb46ff88f574e6844f10fdb14cd3b5/benchmarks/kernels/benchmark_mixtral_moe.py#L24

That would sweep SplitK settings from 2 to 32 and pick the best SplitK and block size parameters based on the problem size and GPU spec.
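The sweep described above amounts to a grid search: enumerate candidate block sizes and SplitK values, time each one, and keep the fastest per problem size. A sketch under stated assumptions: the config key names, the candidate values, and reading "2 - 32" as powers of two are all hypothetical, and `measure` stands in for launching and timing the kernel on the target GPU.

```python
import itertools


def candidate_configs():
    # Hypothetical search space; SPLIT_K covers 2..32 in powers of two.
    for bm, bn, bk, sk in itertools.product(
        [16, 32, 64], [32, 64, 128], [64, 128], [2, 4, 8, 16, 32]
    ):
        yield {"BLOCK_SIZE_M": bm, "BLOCK_SIZE_N": bn, "BLOCK_SIZE_K": bk, "SPLIT_K": sk}


def tune(m, n, k, measure):
    # `measure(config, m, n, k)` returns a runtime (lower is better).
    # min() keeps the first config among equally fast ones.
    return min(candidate_configs(), key=lambda cfg: measure(cfg, m, n, k))


def tune_all(ms, n, k, measure):
    # One best config per batch size M, since the winner depends on problem size.
    return {m: tune(m, n, k, measure) for m in ms}
```

In practice `measure` would wrap the kernel launch with device synchronization and repeated timing, and the resulting per-M table would be specific to the GPU it was measured on.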

@robertgshaw2-redhat
Author

Does it ever make sense to use SplitK over ColumnMajor?

@AdnanHoque
Contributor

AdnanHoque commented Apr 8, 2024

Short Answer:
No, not if we are able to figure out the ColumnMajor fix :)

Long Answer:
Ideally we could use both in tandem, as they are both "inference" optimizations. However, combining ColumnMajor and SplitK proved a little trickier. There may be a GEMM schedule, extremely well optimized for a SplitK work decomposition, that also has the column-major properties (i.e., maximum reuse of the weight matrix B) while optimally saturating the GPU's SMs when M is smaller than N and K.
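The column-major idea concerns launch order: if consecutive program IDs map to output tiles that share the same column block of N, programs running at the same time read the same tile of the weight matrix B, which improves its reuse. A generic sketch of the two mappings, assuming a simple 2-D grid of `num_pid_m x num_pid_n` output tiles and treating "the first `wave` programs" as running concurrently (a simplification); it is an illustration of the idea, not the repo's kernel code.

```python
def row_major_tile(pid, num_pid_m, num_pid_n):
    # Consecutive pids walk along N: they share an A row-block but each
    # needs a different B (weight) column-block. num_pid_m is unused but
    # kept so both mappings have the same signature.
    return pid // num_pid_n, pid % num_pid_n


def column_major_tile(pid, num_pid_m, num_pid_n):
    # Consecutive pids walk along M: they share the same B column-block.
    return pid % num_pid_m, pid // num_pid_m


def b_blocks_in_wave(order, wave, num_pid_m, num_pid_n):
    # Distinct B column-blocks needed by the first `wave` programs,
    # a rough proxy for how much of B must be cached at once.
    return len({order(pid, num_pid_m, num_pid_n)[1] for pid in range(wave)})


# 2 x 3 grid of output tiles, 2 programs resident at a time.
print([column_major_tile(p, 2, 3) for p in range(6)])
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]
print(b_blocks_in_wave(row_major_tile, 2, 2, 3), b_blocks_in_wave(column_major_tile, 2, 2, 3))
# 2 1
```

Both orders cover every output tile exactly once; only the order, and therefore which B tiles are live at the same time, changes.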

@robertgshaw2-redhat
Author

I see. We could, for instance, add an oracle that picks SplitK if it is better in the low-M regime.
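Such an oracle could be as simple as a lookup keyed on M: benchmark both kernels at a few batch sizes, record the winner, and at runtime use the entry whose benchmarked M is closest. A minimal sketch; the table entries are placeholders, not measurements, and `pick_kernel` is a hypothetical name.

```python
# Hypothetical tuned table: benchmarked M -> winning kernel.
# Placeholder entries for illustration only, not measured results.
TUNED = {1: "splitk", 16: "splitk", 64: "column_major", 512: "column_major"}


def pick_kernel(m, tuned=TUNED):
    # Use the entry whose benchmarked M is closest to the runtime M;
    # on a tie, min() keeps the first key in the table's order.
    nearest = min(tuned, key=lambda tm: abs(tm - m))
    return tuned[nearest]
```

For example, `pick_kernel(33)` falls back to the M=16 entry, while `pick_kernel(2048)` uses the M=512 entry.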

@robertgshaw2-redhat
Author

Anyway, thanks for your great work! I'll check back in a day or so to see how progress goes.

@robertgshaw2-redhat
Author

Indeed, making the change suggested in #16 gets the tests to pass.

@AdnanHoque
Contributor

AdnanHoque commented Apr 10, 2024

Yeah, it appears there is a bug in the column-major code, and with the fix applied the speedup from this optimization is gone, which is quite unfortunate.

In this case we can turn to SplitK (V1); I think you'll be able to see some gain from it, although tl.atomic_add does not yet support bfloat16. It looks like it was able to provide some speedup on the FP8 side of things: vllm-project/vllm#3954

I'm still unclear on why it passes the cases that we use, so I will need to dig a bit deeper.

@robertgshaw2-redhat
Author

robertgshaw2-redhat commented Apr 10, 2024

That's too bad!

I also noticed that the results are very close in the cases I created.

Please ping me if you make any more progress. I will try out the SplitK this weekend.

2 participants