
is it time to rerun the benchmarks? #1639

Open · stas00 opened this issue Oct 12, 2024 · 16 comments

stas00 commented Oct 12, 2024

Hi SGLang team,

I have just tried SGLang for the first time - it was probably one of the easiest projects to set up and launch. It literally took me a few minutes to go from 0 to serving - awesome!!! Thank you for making it so easy on the user.

I have just benchmarked vllm==0.6.2 vs sglang==0.3.2 on 2 H100s with an 8B Llama 3 model and tp=2, and I get vllm slightly faster than sglang, yet your benchmark section shows a very different picture. Would it be possible to re-benchmark, and could you tell me whether I am missing some optimization flags needed to see the results you get? I'm just checking the baseline at the moment - no quantization and such; I will get there a bit later. FWIW, I have just benchmarked vllm and it got a massive throughput speedup in v0.6.2 over v0.5 (https://x.com/StasBekman/status/1844886291378470966) - which is probably why the benchmark on your site needs a refresher.

Thank you!

Below are the stats and command lines so that it's reproducible by others.

vllm==0.6.2 w/ default settings

============ Serving Benchmark Result ============
Successful requests:                     50
Benchmark duration (s):                  5.30
Total input tokens:                      12180
Total generated tokens:                  11255
Request throughput (req/s):              9.43
Output token throughput (tok/s):         2121.93
Total Token throughput (tok/s):          4418.26
---------------Time to First Token----------------
Mean TTFT (ms):                          367.92
Median TTFT (ms):                        375.01
P99 TTFT (ms):                           378.64
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          5.80
Median TPOT (ms):                        6.53
P99 TPOT (ms):                           6.75
---------------Inter-token Latency----------------
Mean ITL (ms):                           6.54
Median ITL (ms):                         6.87
P99 ITL (ms):                            8.56
==================================================

vllm==0.6.2 w/ --num-scheduler-steps 8

============ Serving Benchmark Result ============
Successful requests:                     50
Benchmark duration (s):                  4.44
Total input tokens:                      12180
Total generated tokens:                  11249
Request throughput (req/s):              11.27
Output token throughput (tok/s):         2535.33
Total Token throughput (tok/s):          5280.50
---------------Time to First Token----------------
Mean TTFT (ms):                          242.44
Median TTFT (ms):                        231.79
P99 TTFT (ms):                           279.11
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          6.42
Median TPOT (ms):                        5.82
P99 TPOT (ms):                           12.75
---------------Inter-token Latency----------------
Mean ITL (ms):                           44.78
Median ITL (ms):                         44.70
P99 ITL (ms):                            101.35
==================================================

sglang==0.3.2

============ Serving Benchmark Result ============
Successful requests:                     50
Benchmark duration (s):                  5.47
Total input tokens:                      12180
Total generated tokens:                  11514
Request throughput (req/s):              9.14
Output token throughput (tok/s):         2104.62
Total Token throughput (tok/s):          4330.98
---------------Time to First Token----------------
Mean TTFT (ms):                          240.62
Median TTFT (ms):                        242.29
P99 TTFT (ms):                           326.44
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          8.84
Median TPOT (ms):                        7.19
P99 TPOT (ms):                           26.74
---------------Inter-token Latency----------------
Mean ITL (ms):                           7.04
Median ITL (ms):                         6.53
P99 ITL (ms):                            10.16
==================================================

the servers

vllm:

python -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 --port 9999 \
    --model meta-llama/Meta-Llama-3-8B-Instruct \
    --tokenizer meta-llama/Meta-Llama-3-8B-Instruct \
    --dtype=bfloat16 \
    --seed 42 \
    --gpu_memory_utilization 0.8 \
    --num-scheduler-steps 8 \
    -tp 2

sglang:

python -m sglang.launch_server --port 9999 --tp 2  --model-path meta-llama/Meta-Llama-3-8B-Instruct

the benchmark client

git clone https://github.com/vllm-project/vllm
cd vllm/benchmarks
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
mkdir results
python benchmark_serving.py \
    --backend vllm \
    --model meta-llama/Meta-Llama-3-8B-Instruct \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
    --port 9999 \
    --save-result \
    --result-dir results \
    --result-filename test.json \
    --num-prompts 50 \
    --request-rate inf \
    --seed 42
zhyncs (Member) commented Oct 12, 2024

Hi @stas00

First of all, thank you for your issue.

Your description reveals several issues, which I will point out here. If you have any questions, we can continue the discussion.

Regarding the figure in the README, you can refer to https://github.com/sgl-project/sglang/tree/main/benchmark/blog_v0_2, which gives a detailed description of the versions used and how to reproduce the results. Regarding the performance improvement of vLLM v0.6, we have also conducted a benchmark, which can be found at https://github.com/sgl-project/sglang/tree/main/benchmark/benchmark_vllm_060. vLLM v0.6 has indeed improved significantly, but there are some limitations: enabling or disabling options such as --num-scheduler-steps or --multi-step-stream-outputs may improve some metrics while worsening TTFT or ITL relative to the baseline. You need to understand the tradeoff behind this rather than being drawn in by a single metric.

Meanwhile, there are some issues with the parameters you used when benchmarking SGLang. For your testing scenario, you should use --disable-radix and --enable-torch-compile. Additionally, for an 8B model, tp 1 is sufficient; there's no need to use tp 2.

Overall, there are many aspects to consider with benchmarking. Both the configuration of the benchmark and the configuration of the server itself can significantly impact the results. We need to focus on the overall performance metrics rather than local ones.
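
Putting the server-side suggestions above together, the adjusted launch would look roughly like this (just a sketch; --disable-radix-cache is the full flag name, as used in a later comment in this thread):

python -m sglang.launch_server --port 9999 --tp 1 \
    --model-path meta-llama/Meta-Llama-3-8B-Instruct \
    --disable-radix-cache --enable-torch-compile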

stas00 (Author) commented Oct 14, 2024

Thank you for your reply, Yineng.

Thank you for sharing the vllm==0.6.0 vs sglang benchmark. This is great and fits right into the OP.

Your front page shows vLLM throughput being much, much worse than SGLang's.

[Figure: 8B throughput comparison plot from the SGLang front page - https://lmsys.org/images/blog/sglang_llama3/8b_throughput.svg]

The benchmark you have shared shows that vLLM is slightly worse, which is a very different situation. That's why I was suggesting that a new visual is needed to show the updated reality.

Please note that the first results table I shared doesn't use --num-scheduler-steps, so I was comparing apples to apples. Since the setup I had to benchmark was using tp=2, I had to benchmark SGLang with tp=2 as well.

But let's wrap up the vLLM vs SGLang comparison, as I wasn't seeking to provoke - I was just hoping for a fair representation of vLLM, which currently appears to be very inferior on that plot you published many months ago.

===============================

If I get the resources, my intention is to support multiple inference backends in our team's inference framework and switch between them depending on which backend performs best in each particular use case - or offers better stability.

Let's move on to how I can make SGLang shine. Thank you for the tip to add --disable-radix and --enable-torch-compile. Tomorrow is a holiday, so I'm looking forward to re-running the benchmarks on Tuesday.

It sounds like very low TTFT is one of the main objectives of SGLang, correct? We currently do mainly offline generation, so TTFT doesn't matter - which is why I was benchmarking throughput - but it will become hugely important later when we are user-facing, and I'm excited to use SGLang when very low TTFT is crucial.

One other thing I was puzzling over is how I could do outlines-style structured JSON generation: there I pass the target JSON schema and everything is taken care of automatically, whereas with SGLang it appears I need to manually create a regex - is there any reason this can't be automated? I really liked your blog post about switching to prefill during structured generation, when the next few tokens are fixed and require no generation, and I wanted to try it out in practice. Though I'm starting to diverge here and should probably start a separate thread on this.

stas00 (Author) commented Oct 15, 2024

I had a chance to rerun the benchmark with --disable-radix and --enable-torch-compile, and throughput is worse than without them; TTFT is better.

sglang==0.3.2

============ Serving Benchmark Result ============
Successful requests:                     50
Benchmark duration (s):                  5.47
Total input tokens:                      12180
Total generated tokens:                  11514
Request throughput (req/s):              9.14
Output token throughput (tok/s):         2104.62
Total Token throughput (tok/s):          4330.98
---------------Time to First Token----------------
Mean TTFT (ms):                          240.62
Median TTFT (ms):                        242.29
P99 TTFT (ms):                           326.44
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          8.84
Median TPOT (ms):                        7.19
P99 TPOT (ms):                           26.74
---------------Inter-token Latency----------------
Mean ITL (ms):                           7.04
Median ITL (ms):                         6.53
P99 ITL (ms):                            10.16
==================================================

sglang==0.3.2 + --disable-radix --enable-torch-compile

============ Serving Benchmark Result ============
Successful requests:                     50
Benchmark duration (s):                  6.05
Total input tokens:                      12180
Total generated tokens:                  11505
Request throughput (req/s):              8.27
Output token throughput (tok/s):         1902.84
Total Token throughput (tok/s):          3917.31
---------------Time to First Token----------------
Mean TTFT (ms):                          195.35
Median TTFT (ms):                        170.29
P99 TTFT (ms):                           252.13
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          8.84
Median TPOT (ms):                        7.82
P99 TPOT (ms):                           18.21
---------------Inter-token Latency----------------
Mean ITL (ms):                           7.71
Median ITL (ms):                         7.44
P99 ITL (ms):                            10.17
==================================================

the baseline command was:

python -m sglang.launch_server --port 9999 --tp 2  --model-path meta-llama/Meta-Llama-3-8B-Instruct

it also took forever to start with --enable-torch-compile

Thoughts?

zhyncs (Member) commented Oct 17, 2024

@stas00 If you care about offline scenarios and are focused on throughput, you should maximize the batch size and make full use of VRAM, meaning KV cache usage should be as high as possible. However, your benchmark command only issues 50 requests; running Llama 3 8B on two H100 devices is nowhere near the true throughput limit with just 50 requests.
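
For illustration, here is the same client command with a much larger prompt count, which pushes the server much closer to that limit (a sketch; same dataset, model, and port as above):

python benchmark_serving.py \
    --backend vllm \
    --model meta-llama/Meta-Llama-3-8B-Instruct \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
    --port 9999 \
    --num-prompts 2000 \
    --request-rate inf \
    --seed 42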

zhyncs (Member) commented Oct 17, 2024

I conducted a simple benchmark with 2000 prompts, which you can use as a reference. The commands are largely the same as those you provided above; vLLM is version 0.6.3 and SGLang is the latest main. The vLLM startup command drops --gpu_memory_utilization 0.8 and adds --disable-log-requests.
From the results, you can see that in the maximum-throughput scenario you care about, SGLang outperforms vLLM (24022.63 vs 22116.03 total tok/s).

# SGLang
============ Serving Benchmark Result ============
Successful requests:                     2000
Benchmark duration (s):                  34.55
Total input tokens:                      453502
Total generated tokens:                  376493
Request throughput (req/s):              57.89
Output token throughput (tok/s):         10896.87
Total Token throughput (tok/s):          24022.63
---------------Time to First Token----------------
Mean TTFT (ms):                          4878.89
Median TTFT (ms):                        4804.10
P99 TTFT (ms):                           8589.78
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          194.03
Median TPOT (ms):                        74.07
P99 TPOT (ms):                           1647.41
---------------Inter-token Latency----------------
Mean ITL (ms):                           78.71
Median ITL (ms):                         34.42
P99 ITL (ms):                            518.89
==================================================

# vLLM
============ Serving Benchmark Result ============
Successful requests:                     2000
Benchmark duration (s):                  37.58
Total input tokens:                      453502
Total generated tokens:                  377607
Request throughput (req/s):              53.22
Output token throughput (tok/s):         10048.22
Total Token throughput (tok/s):          22116.03
---------------Time to First Token----------------
Mean TTFT (ms):                          15410.77
Median TTFT (ms):                        15314.09
P99 TTFT (ms):                           32016.97
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          30.23
Median TPOT (ms):                        28.58
P99 TPOT (ms):                           78.09
---------------Inter-token Latency----------------
Mean ITL (ms):                           25.85
Median ITL (ms):                         19.51
P99 ITL (ms):                            273.56
==================================================

stas00 (Author) commented Oct 17, 2024

This is very useful, @zhyncs

  1. I'm able to reproduce that SGLang consistently outperforms vLLM - I even got better throughput than you did - but do we agree that the difference isn't the 2x that the benchmark on the front page suggests? (circling back to the OP)

  2. You suggested adding the two flags --disable-radix --enable-torch-compile, but in my benchmarking they make the results worse - any suggestions on what I might be doing wrong?

  3. It looks like vLLM had a speed regression from 0.6.2 to 0.6.3 ([Performance]: speed regression 0.6.2 => 0.6.3? vllm-project/vllm#9476) - I get better throughput with 0.6.2, though it's still behind SGLang.

Here are the benchmark results:

2000 concurrent requests benchmark

Benchmark

git clone https://github.com/vllm-project/vllm
cd vllm/benchmarks
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
mkdir results
python benchmark_serving.py \
    --backend vllm \
    --model meta-llama/Meta-Llama-3-8B-Instruct \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
    --port 9999 \
    --save-result \
    --result-dir results \
    --result-filename test.json \
    --num-prompts 2000 \
    --request-rate inf \
    --seed 42

sglang==0.3.2

python -m sglang.launch_server --host $(hostname) --host 0.0.0.0 --port 9999 --tp 2  --model-path  meta-llama/Meta-Llama-3-8B-Instruct

============ Serving Benchmark Result ============
Successful requests:                     2000
Benchmark duration (s):                  30.80
Total input tokens:                      453502
Total generated tokens:                  376822
Request throughput (req/s):              64.94
Output token throughput (tok/s):         12235.46
Total Token throughput (tok/s):          26960.73
---------------Time to First Token----------------
Mean TTFT (ms):                          2724.71
Median TTFT (ms):                        2866.76
P99 TTFT (ms):                           3483.93
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          85.08
Median TPOT (ms):                        51.94
P99 TPOT (ms):                           367.09
---------------Inter-token Latency----------------
Mean ITL (ms):                           56.85
Median ITL (ms):                         34.19
P99 ITL (ms):                            419.42
==================================================

sglang==0.3.2 + --disable-radix --enable-torch-compile

python -m sglang.launch_server --host $(hostname) --host 0.0.0.0 --port 9999 --tp 2  --model-path  meta-llama/Meta-Llama-3-8B-Instruct \
--disable-radix --enable-torch-compile


============ Serving Benchmark Result ============
Successful requests:                     2000
Benchmark duration (s):                  39.77
Total input tokens:                      453502
Total generated tokens:                  376435
Request throughput (req/s):              50.29
Output token throughput (tok/s):         9464.84
Total Token throughput (tok/s):          20867.41
---------------Time to First Token----------------
Mean TTFT (ms):                          5574.30
Median TTFT (ms):                        5403.57
P99 TTFT (ms):                           9824.15
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          208.52
Median TPOT (ms):                        84.46
P99 TPOT (ms):                           1794.57
---------------Inter-token Latency----------------
Mean ITL (ms):                           84.73
Median ITL (ms):                         39.16
P99 ITL (ms):                            908.29
==================================================

vllm==0.6.2

python -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 --port 9999 \
    --model meta-llama/Meta-Llama-3-8B-Instruct \
    --tokenizer meta-llama/Meta-Llama-3-8B-Instruct \
    --dtype=bfloat16 \
    --seed 42 \
    --num-scheduler-steps 8 \
    --disable-log-requests \
    -tp 2

============ Serving Benchmark Result ============
Successful requests:                     2000
Benchmark duration (s):                  37.56
Total input tokens:                      453502
Total generated tokens:                  377235
Request throughput (req/s):              53.24
Output token throughput (tok/s):         10042.39
Total Token throughput (tok/s):          22115.08
---------------Time to First Token----------------
Mean TTFT (ms):                          13418.88
Median TTFT (ms):                        13693.80
P99 TTFT (ms):                           27527.70
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          19.64
Median TPOT (ms):                        18.37
P99 TPOT (ms):                           74.99
---------------Inter-token Latency----------------
Mean ITL (ms):                           137.43
Median ITL (ms):                         136.93
P99 ITL (ms):                            506.59
==================================================

vllm==0.6.3.post1

python -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 --port 9999 \
    --model meta-llama/Meta-Llama-3-8B-Instruct \
    --tokenizer meta-llama/Meta-Llama-3-8B-Instruct \
    --dtype=bfloat16 \
    --seed 42 \
    --num-scheduler-steps 8 \
    --disable-log-requests \
    -tp 2


============ Serving Benchmark Result ============
Successful requests:                     2000
Benchmark duration (s):                  43.14
Total input tokens:                      453502
Total generated tokens:                  378114
Request throughput (req/s):              46.36
Output token throughput (tok/s):         8764.77
Total Token throughput (tok/s):          19277.04
---------------Time to First Token----------------
Mean TTFT (ms):                          16705.17
Median TTFT (ms):                        15436.29
P99 TTFT (ms):                           36631.51
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          36.76
Median TPOT (ms):                        33.73
P99 TPOT (ms):                           110.12
---------------Inter-token Latency----------------
Mean ITL (ms):                           30.63
Median ITL (ms):                         23.22
P99 ITL (ms):                            318.24
==================================================

zhyncs (Member) commented Oct 17, 2024

The radix prefix cache has some overhead in SGLang. We've recently worked to reduce it, but a new version hasn't been released yet, so I recommend adding --disable-radix for benchmarking. When running small models like Llama 8B, I suggest enabling torch compile: it outperforms gpt-fast and is beneficial for latency-sensitive scenarios. Ultimately, the choice is yours.

zhyncs (Member) commented Oct 17, 2024

ref https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/the_first_sglang_online_meetup.pdf

stas00 (Author) commented Oct 17, 2024

You can see my results above - either or both of --disable-radix and --enable-torch-compile make throughput worse.

zhyncs (Member) commented Oct 18, 2024

--disable-radix theoretically does not reduce throughput. In my experience, when the radix cache is enabled on the ShareGPT dataset, the benefit of cache sharing is smaller than the overhead of the cache itself.

zhyncs (Member) commented Oct 25, 2024

Hi @stas00, I would like to share another perspective with you - how I approach this when building a product.

============ Serving Benchmark Result ============
Backend:                                 trt
Traffic request rate:                    46.0
Successful requests:                     1000
Benchmark duration (s):                  31.14
Total input tokens:                      224442
Total generated tokens:                  190594
Total generated tokens (retokenized):    189883
Request throughput (req/s):              32.11
Input token throughput (tok/s):          7206.88
Output token throughput (tok/s):         6120.02
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   2703.69
Median E2E Latency (ms):                 1710.39
---------------Time to First Token----------------
Mean TTFT (ms):                          39.01
Median TTFT (ms):                        33.00
P99 TTFT (ms):                           143.08
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          14.35
Median TPOT (ms):                        14.02
P99 TPOT (ms):                           22.11
---------------Inter-token Latency----------------
Mean ITL (ms):                           14.05
Median ITL (ms):                         11.11
P99 ITL (ms):                            88.29
==================================================

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    46.0
Successful requests:                     1000
Benchmark duration (s):                  33.24
Total input tokens:                      224442
Total generated tokens:                  190594
Total generated tokens (retokenized):    190555
Request throughput (req/s):              30.09
Input token throughput (tok/s):          6752.44
Output token throughput (tok/s):         5734.10
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   4299.44
Median E2E Latency (ms):                 2941.98
---------------Time to First Token----------------
Mean TTFT (ms):                          50.22
Median TTFT (ms):                        44.58
P99 TTFT (ms):                           271.68
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          24.57
Median TPOT (ms):                        24.47
P99 TPOT (ms):                           45.20
---------------Inter-token Latency----------------
Mean ITL (ms):                           22.83
Median ITL (ms):                         16.09
P99 ITL (ms):                            89.85
==================================================

============ Serving Benchmark Result ============
Backend:                                 vllm
Traffic request rate:                    46.0
Successful requests:                     1000
Benchmark duration (s):                  43.54
Total input tokens:                      224442
Total generated tokens:                  190594
Total generated tokens (retokenized):    190565
Request throughput (req/s):              22.97
Input token throughput (tok/s):          5155.04
Output token throughput (tok/s):         4377.61
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   9189.69
Median E2E Latency (ms):                 7653.53
---------------Time to First Token----------------
Mean TTFT (ms):                          1810.25
Median TTFT (ms):                        177.19
P99 TTFT (ms):                           6712.04
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          43.26
Median TPOT (ms):                        45.88
P99 TPOT (ms):                           71.70
---------------Inter-token Latency----------------
Mean ITL (ms):                           38.91
Median ITL (ms):                         27.71
P99 ITL (ms):                            271.40
==================================================

Running Llama 3.1 8B Instruct with TP 1 on an H100, using the open-source TensorRT LLM v0.13.0, SGLang's latest commit, and vLLM 0.6.3.post1. The goal is a P99 TTFT below 200 ms, benchmarked with 1k prompts. TensorRT LLM can sustain a maximum request rate of 46 before higher latency fails to meet that requirement; using the same benchmark configuration for SGLang yields the results shown above. In short, TensorRT LLM does have certain advantages in latency-sensitive online scenarios. Users often have specific latency requirements to meet: for example, if their service budget maxes out at 400 ms, they typically allocate around 200 ms of P99 TTFT to LLM inference serving (just a simple example, not necessarily accurate). Under conditions that satisfy the latency requirement, higher throughput is better.

Note: TensorRT LLM v0.13.0, SGLang latest main, vLLM v0.6.3.post1

python3 -m sglang.bench_serving --backend trt --request-rate 46 --model Llama-3.1-8B-Instruct
python3 -m sglang.bench_serving --backend sglang --request-rate 46
python3 -m sglang.bench_serving --backend vllm --request-rate 46
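
To find that maximum sustainable rate, one can sweep the request rate and read off the P99 TTFT; here is a rough bash sketch (it assumes the server is already running and that bench_serving prints the "P99 TTFT (ms):" line shown in the results above):

for rate in 30 38 42 46 50; do
    p99=$(python3 -m sglang.bench_serving --backend sglang --request-rate $rate \
          | grep "P99 TTFT (ms):" | awk '{print $NF}')
    echo "request-rate=$rate  P99 TTFT=${p99} ms"
done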

zhyncs (Member) commented Oct 25, 2024

python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --enable-overlap-schedule

python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests --max_model_len 2048

stas00 (Author) commented Oct 25, 2024

That's very cool, @zhyncs! Thank you for running additional benchmarks.

Adding --num-scheduler-steps 8 --multi-step-stream-outputs=False should make vLLM faster in 0.6.3+, according to this thread: vllm-project/vllm#9476.

zhyncs (Member) commented Oct 25, 2024

python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests --max_model_len 2048 --num-scheduler-steps 8 --multi-step-stream-outputs=False
python3 -m sglang.bench_serving --backend vllm --request-rate 46
============ Serving Benchmark Result ============
Backend:                                 vllm
Traffic request rate:                    46.0
Successful requests:                     1000
Benchmark duration (s):                  32.53
Total input tokens:                      224442
Total generated tokens:                  190594
Total generated tokens (retokenized):    190565
Request throughput (req/s):              30.74
Input token throughput (tok/s):          6899.73
Output token throughput (tok/s):         5859.18
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   3663.89
Median E2E Latency (ms):                 2412.93
---------------Time to First Token----------------
Mean TTFT (ms):                          92.97
Median TTFT (ms):                        82.02
P99 TTFT (ms):                           289.57
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          19.66
Median TPOT (ms):                        19.51
P99 TPOT (ms):                           36.61
---------------Inter-token Latency----------------
Mean ITL (ms):                           147.97
Median ITL (ms):                         139.65
P99 ITL (ms):                            369.77
==================================================

@stas00 Thanks for your advice - cool, it does get better. However, the performance is still worse than TensorRT LLM and SGLang.

stas00 (Author) commented Oct 25, 2024

The TTFT is indeed much worse, but vLLM's throughput is higher than SGLang's according to the updated numbers you shared.

And circling back to the OP, it's hopefully very clear now that the plots on SGLang's front page need a refresh to bring them up to date ;)

zhyncs (Member) commented Oct 25, 2024

As described above in #1639 (comment), I use this benchmark config to test the online scenario. If we want to test maximum throughput for the offline case, we should maximize the batch size and make full use of VRAM, meaning KV cache usage should be as high as possible. In that setting, SGLang's throughput is also higher.
