Conversation

@diegocastanibm

@diegocastanibm diegocastanibm commented Aug 19, 2025

Purpose

Currently, vLLM captures CUDA graphs as part of engine initialization, which significantly slows down vLLM startup. We propose capturing CUDA graphs lazily: instead of performing the dummy runs during the engine initialization phase, we perform them during execution. The highest capture priority is given to the current runtime shape if it is not cached already; otherwise, we capture the largest shape that has not been captured yet.
More info in this issue.
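
For intuition, the size-selection logic can be sketched as follows (hypothetical helper and variable names; the PR itself tracks the not-yet-captured sizes in incomplete_cudagraph_capture):

def pick_next_capture_size(num_scheduled_tokens: int,
                           pending_sizes: list[int]) -> int | None:
    """Choose which CUDA graph size to capture on this execution step (sketch).

    pending_sizes: capture sizes not yet captured, sorted in descending order.
    """
    if not pending_sizes:
        return None  # everything is captured already
    # Highest priority: the shape the current batch actually needs.
    if num_scheduled_tokens in pending_sizes:
        pending_sizes.remove(num_scheduled_tokens)
        return num_scheduled_tokens
    # Otherwise: the largest shape that has not been captured yet.
    return pending_sizes.pop(0)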

Test Plan

Baseline with current approach:

Server:
>> vllm serve meta-llama/Llama-3.1-8B-Instruct

Benchmark:
>> vllm bench serve --backend vllm --model meta-llama/Llama-3.1-8B-Instruct --endpoint /v1/completions --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 1000

CUDAGraph capturing during model execution:

Server:
>> vllm serve meta-llama/Llama-3.1-8B-Instruct --compilation-config '{"use_cudagraph_delayed_capture": true}'

and 

>> vllm serve meta-llama/Llama-3.1-8B-Instruct --compilation-config '{"use_cudagraph_delayed_capture": true, "cudagraph_mode": "FULL"}'

Benchmark:
>> vllm bench serve --backend vllm --model meta-llama/Llama-3.1-8B-Instruct --endpoint /v1/completions --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 1000

Test Result

Benchmark results
CUDAGraph during the initialization (baseline):

Initial test run completed. Starting main benchmark run...
Traffic request rate: inf
Burstiness factor: 1.0 (Poisson process)
Maximum request concurrency: None
100%|████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:12<00:00, 13.73it/s]
============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  72.82
Total input tokens:                      215196
Total generated tokens:                  197914
Request throughput (req/s):              13.73
Output token throughput (tok/s):         2717.77
Total Token throughput (tok/s):          5672.85
---------------Time to First Token----------------
Mean TTFT (ms):                          20493.77
Median TTFT (ms):                        18837.80
P99 TTFT (ms):                           48674.65
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          79.60
Median TPOT (ms):                        75.12
P99 TPOT (ms):                           168.10
---------------Inter-token Latency----------------
Mean ITL (ms):                           69.61
Median ITL (ms):                         55.84
P99 ITL (ms):                            172.96
==================================================

Delayed CUDAGraph capture during inference using 1000 prompts from ShareGPT:

Initial test run completed. Starting main benchmark run...
Traffic request rate: inf
Burstiness factor: 1.0 (Poisson process)
Maximum request concurrency: None
100%|████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:14<00:00, 13.50it/s]
============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  74.10
Total input tokens:                      215196
Total generated tokens:                  198133
Request throughput (req/s):              13.50
Output token throughput (tok/s):         2673.83
Total Token throughput (tok/s):          5577.93
---------------Time to First Token----------------
Mean TTFT (ms):                          20431.48
Median TTFT (ms):                        18470.78
P99 TTFT (ms):                           48904.79
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          79.64
Median TPOT (ms):                        75.33
P99 TPOT (ms):                           166.94
---------------Inter-token Latency----------------
Mean ITL (ms):                           69.74
Median ITL (ms):                         57.14
P99 ITL (ms):                            173.86
==================================================

Memory and timing of CUDAGraph captures with model meta-llama/Llama-3.1-8B-Instruct:
Normal CUDAGraph capture during initialization:

Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|███████████████████████████████| 67/67 [00:03<00:00, 18.84it/s]
(EngineCore_0 pid=306836) INFO 08-19 06:59:47 [gpu_model_runner.py:2718] Graph capturing finished in 4 secs, took 0.54 GiB

CUDAGraph capture during inference:

During initialization, no CUDAGraph capture -> 0.0 secs

Inference:
(APIServer pid=309263) INFO:     Started server process [309263]
(APIServer pid=309263) INFO:     Waiting for application startup.
(APIServer pid=309263) INFO:     Application startup complete.
(APIServer pid=309263) INFO:     127.0.0.1:48804 - "POST /v1/completions HTTP/1.1" 200 OK
(EngineCore_0 pid=309588) INFO 08-19 07:06:32 [gpu_model_runner.py:2721] Graph capturing for 512 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:32 [gpu_model_runner.py:2721]                         finished in 0.073 secs, took -342.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:33 [gpu_model_runner.py:2721] Graph capturing for 1 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:33 [gpu_model_runner.py:2721]                         finished in 0.036 secs, took 28.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:33 [gpu_model_runner.py:2721] Graph capturing for 504 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:33 [gpu_model_runner.py:2721]                         finished in 0.057 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:33 [gpu_model_runner.py:2721] Graph capturing for 496 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:33 [gpu_model_runner.py:2721]                         finished in 0.057 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:33 [gpu_model_runner.py:2721] Graph capturing for 488 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:33 [gpu_model_runner.py:2721]                         finished in 0.057 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:33 [gpu_model_runner.py:2721] Graph capturing for 480 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:33 [gpu_model_runner.py:2721]                         finished in 0.061 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:33 [gpu_model_runner.py:2721] Graph capturing for 472 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:33 [gpu_model_runner.py:2721]                         finished in 0.057 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:33 [gpu_model_runner.py:2721] Graph capturing for 464 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:33 [gpu_model_runner.py:2721]                         finished in 0.057 secs, took 28.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:33 [gpu_model_runner.py:2721] Graph capturing for 456 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:33 [gpu_model_runner.py:2721]                         finished in 0.055 secs, took 24.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721] Graph capturing for 448 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721]                         finished in 0.057 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721] Graph capturing for 440 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721]                         finished in 0.057 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721] Graph capturing for 432 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721]                         finished in 0.066 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721] Graph capturing for 424 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721]                         finished in 0.057 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721] Graph capturing for 416 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721]                         finished in 0.057 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721] Graph capturing for 408 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721]                         finished in 0.057 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721] Graph capturing for 400 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721]                         finished in 0.057 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721] Graph capturing for 392 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721]                         finished in 0.057 secs, took 28.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721] Graph capturing for 384 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721]                         finished in 0.051 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721] Graph capturing for 376 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721]                         finished in 0.056 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721] Graph capturing for 368 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721]                         finished in 0.052 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721] Graph capturing for 360 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721]                         finished in 0.054 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721] Graph capturing for 352 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:34 [gpu_model_runner.py:2721]                         finished in 0.053 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721] Graph capturing for 344 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721]                         finished in 0.053 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721] Graph capturing for 336 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721]                         finished in 0.052 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721] Graph capturing for 328 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721]                         finished in 0.053 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721] Graph capturing for 320 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721]                         finished in 0.058 secs, took 28.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721] Graph capturing for 312 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721]                         finished in 0.052 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721] Graph capturing for 304 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721]                         finished in 0.050 secs, took 24.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721] Graph capturing for 296 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721]                         finished in 0.052 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721] Graph capturing for 288 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721]                         finished in 0.054 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721] Graph capturing for 280 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721]                         finished in 0.053 secs, took 26.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721] Graph capturing for 272 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721]                         finished in 0.054 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721] Graph capturing for 264 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721]                         finished in 0.053 secs, took 24.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721] Graph capturing for 256 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721]                         finished in 0.051 secs, took -10.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721] Graph capturing for 248 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:35 [gpu_model_runner.py:2721]                         finished in 0.051 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721] Graph capturing for 240 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721]                         finished in 0.051 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721] Graph capturing for 232 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721]                         finished in 0.048 secs, took 4.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721] Graph capturing for 224 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721]                         finished in 0.050 secs, took 8.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721] Graph capturing for 216 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721]                         finished in 0.055 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721] Graph capturing for 208 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721]                         finished in 0.050 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721] Graph capturing for 200 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721]                         finished in 0.050 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721] Graph capturing for 192 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721]                         finished in 0.052 secs, took 10.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721] Graph capturing for 184 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721]                         finished in 0.050 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721] Graph capturing for 176 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721]                         finished in 0.050 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721] Graph capturing for 168 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721]                         finished in 0.050 secs, took 8.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721] Graph capturing for 160 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721]                         finished in 0.057 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721] Graph capturing for 152 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721]                         finished in 0.049 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721] Graph capturing for 144 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721]                         finished in 0.048 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721] Graph capturing for 136 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:36 [gpu_model_runner.py:2721]                         finished in 0.049 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721] Graph capturing for 128 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721]                         finished in 0.054 secs, took 12.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721] Graph capturing for 120 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721]                         finished in 0.048 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721] Graph capturing for 112 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721]                         finished in 0.048 secs, took 8.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721] Graph capturing for 104 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721]                         finished in 0.057 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721] Graph capturing for 96 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721]                         finished in 0.048 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721] Graph capturing for 88 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721]                         finished in 0.049 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721] Graph capturing for 80 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721]                         finished in 0.063 secs, took 8.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721] Graph capturing for 72 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721]                         finished in 0.049 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721] Graph capturing for 64 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721]                         finished in 0.049 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721] Graph capturing for 56 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721]                         finished in 0.046 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721] Graph capturing for 48 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721]                         finished in 0.056 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721] Graph capturing for 40 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721]                         finished in 0.049 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721] Graph capturing for 32 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721]                         finished in 0.048 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721] Graph capturing for 24 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:37 [gpu_model_runner.py:2721]                         finished in 0.047 secs, took 6.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:38 [gpu_model_runner.py:2721] Graph capturing for 16 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:38 [gpu_model_runner.py:2721]                         finished in 0.042 secs, took 4.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:38 [gpu_model_runner.py:2721] Graph capturing for 8 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:38 [gpu_model_runner.py:2721]                         finished in 0.045 secs, took 2.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:38 [gpu_model_runner.py:2721] Graph capturing for 4 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:38 [gpu_model_runner.py:2721]                         finished in 0.045 secs, took 8.00 MiB
(EngineCore_0 pid=309588) INFO 08-19 07:06:38 [gpu_model_runner.py:2721] Graph capturing for 2 input tokens
(EngineCore_0 pid=309588) INFO 08-19 07:06:38 [gpu_model_runner.py:2721]                         finished in 0.046 secs, took 6.00 MiB

CUDA Graph capture during inference with FULL CUDA Graph:

(APIServer pid=310788) INFO:     Started server process [310788]
(APIServer pid=310788) INFO:     Waiting for application startup.
(APIServer pid=310788) INFO:     Application startup complete.
(APIServer pid=310788) INFO:     127.0.0.1:38206 - "POST /v1/completions HTTP/1.1" 200 OK
(EngineCore_0 pid=311071) WARNING 08-19 07:11:41 [gpu_model_runner.py:2907] CUDAGraphMode.FULL is not supported with FlashAttentionMetadataBuilder backend (support: AttentionCGSupport.UNIFORM_BATCH); setting cudagraph_mode=FULL_AND_PIECEWISE
(EngineCore_0 pid=311071) INFO 08-19 07:11:41 [gpu_model_runner.py:2721] Graph capturing for 512 input tokens
(EngineCore_0 pid=311071) INFO 08-19 07:11:41 [gpu_model_runner.py:2721]                         finished in 0.079 secs, took -342.00 MiB
(EngineCore_0 pid=311071) INFO 08-19 07:11:42 [gpu_model_runner.py:2721] Graph capturing for 1 input tokens
(EngineCore_0 pid=311071) INFO 08-19 07:11:42 [gpu_model_runner.py:2721]                         finished in 0.095 secs, took 34.00 MiB
(EngineCore_0 pid=311071) INFO 08-19 07:11:42 [gpu_model_runner.py:2721] Graph capturing for 504 input tokens
(EngineCore_0 pid=311071) INFO 08-19 07:11:42 [gpu_model_runner.py:2721]                         finished in 0.061 secs, took 26.00 MiB
(EngineCore_0 pid=311071) INFO 08-19 07:11:42 [gpu_model_runner.py:2721] Graph capturing for 496 input tokens
(EngineCore_0 pid=311071) INFO 08-19 07:11:42 [gpu_model_runner.py:2721]                         finished in 0.057 secs, took 24.00 MiB
...
...
(EngineCore_0 pid=311071) INFO 08-19 07:12:03 [gpu_model_runner.py:2721] Graph capturing for 168 input tokens
(EngineCore_0 pid=311071) INFO 08-19 07:12:03 [gpu_model_runner.py:2721]                         finished in 1.249 secs, took 14.00 MiB
(EngineCore_0 pid=311071) INFO 08-19 07:12:04 [gpu_model_runner.py:2721] Graph capturing for 160 input tokens
(EngineCore_0 pid=311071) INFO 08-19 07:12:04 [gpu_model_runner.py:2721]                         finished in 1.191 secs, took 10.00 MiB
(EngineCore_0 pid=311071) INFO 08-19 07:12:05 [gpu_model_runner.py:2721] Graph capturing for 152 input tokens
(EngineCore_0 pid=311071) INFO 08-19 07:12:05 [gpu_model_runner.py:2721]                         finished in 1.138 secs, took 12.00 MiB
(EngineCore_0 pid=311071) INFO 08-19 07:12:06 [gpu_model_runner.py:2721] Graph capturing for 144 input tokens
(EngineCore_0 pid=311071) INFO 08-19 07:12:06 [gpu_model_runner.py:2721]                         finished in 1.101 secs, took 12.00 MiB
...

(Optional) Documentation Update


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
Signed-off-by: Diego-Castan <diego.castan@ibm.com>
Signed-off-by: Diego-Castan <diego.castan@ibm.com>
Signed-off-by: Diego-Castan <diego.castan@ibm.com>
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs will not trigger a full CI run by default. Instead, only the fastcheck CI will run, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Aug 19, 2025
@diegocastanibm diegocastanibm mentioned this pull request Aug 19, 2025


@contextmanager
def freeze_gc(allow_collect: bool = True):
Author

This function has been moved from gpu_model_runner.py to utils.py based on this comment.

@dosubot

dosubot bot commented Aug 20, 2025

Related Documentation

No published documentation to review for changes on this repository.

Collaborator

@ProExpertProg ProExpertProg left a comment

Thanks for integrating this! I took a quick look and I have 3 overall pieces of feedback:

  • If a cudagraph exists for the current num_tokens, why are we capturing cudagraphs for sizes we haven't seen? Shouldn't we just wait until we see that size? I know that increases the unpredictability but it also increases savings as we don't capture cudagraphs we rarely use until we do use them.
  • Instead of capturing 0 cudagraphs at the start, should we at least capture the cudagraph for the largest capture size so that future cudagraphs can reuse that memory?
  • Could we build in a mechanism to capture cudagraphs while the server sits idle?
"""Sizes to capture cudagraph.
- None (default): capture sizes are inferred from vllm config.
- list[int]: capture sizes are specified as given."""
use_cudagraph_delayed_capture: bool = False
Collaborator

Nit: I would call this lazy_cudagraph_capture or delay_cudagraph_capture

Author

@diegocastanibm diegocastanibm Aug 21, 2025

Agreed. To maintain consistency across our codebase, I recommend retaining the use of the use_ prefix before variable names. This convention provides a clear understanding of the variable's nature and its boolean type, as demonstrated by examples like use_cudagraph or use_inductor.

Comment on lines 289 to 290
to note that this speedup during initialization may result in an
increased Time-To-First-Token (TTFT) for the initial token inferences."""
Collaborator

I would also say it makes TTFT less predictable

Author

Agreed. I'll add it.

Comment on lines 2661 to 2662
gc_collect = (
    not specific_token_num) if specific_token_num is not None else True
Collaborator

What is this gc_collect - it's very unclear when it's supposed to be true or not

Author

The gc_collect variable is used within the freeze_gc function to enable garbage collection. However, this operation can consume a considerable amount of time and negatively impact the performance of cudagraph capture. In eager mode (during initialization), the collect operation is performed only once, which is acceptable. To avoid repeating the collection for every single cudagraph captured in lazy mode, we set gc_collect to False when specific_token_num is used to capture just one specific cudagraph. In that case, the garbage collector will not collect; the function will only perform the freeze for that single capture.
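
For reference, a minimal sketch of what such a context manager could look like (it mirrors the freeze_gc signature quoted above, not the exact PR code):

import gc
from contextlib import contextmanager

@contextmanager
def freeze_gc(allow_collect: bool = True):
    # Freeze the GC around cudagraph capture; collecting first is optional
    # because it is slow and only worth doing once for the full eager capture.
    if allow_collect:
        gc.collect()
    gc.freeze()
    try:
        yield
    finally:
        gc.unfreeze()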

    not specific_token_num) if specific_token_num is not None else True
set_cudagraph_capturing_enabled(
    True) if not specific_token_num else None
with freeze_gc(gc_collect), graph_capture(device=self.device):
Collaborator

Either call this maybe_freeze_gc or preferably use an ExitStack and call enter_context inside an if statement
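
For illustration, the ExitStack variant could look roughly like this (a sketch reusing names from the diff above, not the actual PR code):

from contextlib import ExitStack

def _capture(self, specific_token_num: int | None = None):
    with ExitStack() as stack:
        if specific_token_num is None:
            # Full eager capture: freeze (and collect) the GC once.
            stack.enter_context(freeze_gc())
        stack.enter_context(graph_capture(device=self.device))
        # ... capture one or all cudagraph sizes here ...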

Author

OK


# gc collector is a time consuming operation
if allow_collect:
    gc.collect()
Collaborator

Just call gc.collect outside the function, doesn't need to be inside the context manager

Author

OK

all_gather_group=get_tp_group()))

if (self.vllm_config.compilation_config.use_cudagraph_delayed_capture
        and not self.model_config.enforce_eager
Collaborator

Remove check for enforce_eager here - instead set use_cudagraph_delayed_capture to False in config init if enforce_eager is set
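
A sketch of the suggested config-side handling (hypothetical placement, e.g. in the config's post-init/validation step):

# Sketch: disable delayed capture up front when cudagraphs are disabled anyway.
if model_config.enforce_eager:
    compilation_config.use_cudagraph_delayed_capture = False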

Author

Agreed.

# Check if there are any entries left in _token_compiled_cudagraphs
else:
    # Update next_comp to the first item and remove it from the list
    next_capture = self.incomplete_cudagraph_capture.pop(0)
Collaborator

Why would we want to capture cgs for an unrelated number of tokens?

Author

I think this comment aligns with my response to your first bullet: we do not know what number of tokens future inferences will have, but we want to keep the window of unpredictability as short as possible.

@diegocastanibm
Author

Hi @ProExpertProg ,

Thanks a lot for your review and your suggestions. I essentially agree with many of them. Individual comments for each bullet below:

Thanks for integrating this! I took a quick look and I have 3 overall pieces of feedback:

* If a cudagraph exists for the current `num_tokens`, why are we capturing cudagraphs for sizes we haven't seen? Shouldn't we just wait until we see that size? I know that increases the unpredictability but it also increases savings as we don't capture cudagraphs we rarely use until we do use them.

As @lionelvillard told me, our primary objectives with this approach are twofold: 1) minimizing the startup time in a cold-start scenario, and 2) controlling the unpredictability at the initial stages of the inference process. As you have already noticed, there is a trade-off between predictability and speed. However, in a situation where multiple users share the same node, limiting that variability to seconds or minutes could yield more substantial benefits over time.

* Instead of capturing 0 cudagraphs at the start, should we at least capture the cudagraph for the largest capture size so that future cudagraphs can reuse that memory?

I agree with this one. Makes sense to me and I'll implement it.

* Could we build in a mechanism to capture cudagraphs while the server sits idle?

Good idea. I propose to do it in a separate PR to keep the code small and clean.

@ProExpertProg
Collaborator

ProExpertProg commented Aug 21, 2025

@diegocastanibm how about turning the delayed flag into:

class CUDAGraphCaptureStrategy(enum.Enum):
    STARTUP = 0 # Capture all cudagraphs during startup (default)
    DELAYED = 1 # Current approach in the PR, capture one cudagraph for every request.
    LAZY = 2 # What I proposed, only capture necessary cudagraphs

Also, after we capture the first (largest) cudagraph, we should report the memory usage and warn if we think we might run out of memory.

  • Could we build in a mechanism to capture cudagraphs while the server sits idle?
    Good idea. I propose to do it in a separate PR to keep the code small and clean.

Sounds good!

@ProExpertProg
Collaborator

Also, I think it would be great if you could add tests to the existing cudagraph dispatcher and wrapper tests.

Collaborator

@WoosukKwon WoosukKwon left a comment

Hi @diegocastanibm, thanks for the PR.

I think capturing cuda graphs at runtime affects reliability a lot.
Also for simplicity, I think we should strictly limit it to init time only.

* Change name to use_delay_cudagraph_capture
* Capture largest size during init
* Better description
* Changes in GC
* Enforce_eager and use_delay_cudagraph_capture in config init

Signed-off-by: Diego-Castan <diego.castan@ibm.com>
@diegocastanibm
Author

@diegocastanibm how about turning the delayed flag into:

class CUDAGraphCaptureStrategy(enum.Enum):
    STARTUP = 0 # Capture all cudagraphs during startup (default)
    DELAYED = 1 # Current approach in the PR, capture one cudagraph for every request.
    LAZY = 2 # What I proposed, only capture necessary cudagraphs

Also, after we capture the first (largest) cudagraph, we should report the memory usage and warn if we think we might run out of memory.

  • Could we build in a mechanism to capture cudagraphs while the server sits idle?
    Good idea. I propose to do it in a separate PR to keep the code small and clean.

Sounds good!

Indeed, it is feasible and beneficial to present additional choices. I will implement it soon.

@diegocastanibm
Author

Also, I think it would be great if you could add tests to the existing cudagraph dispatcher and wrapper tests.

Sure! No problem

Contributor

@fhl2000 fhl2000 left a comment

Hi @diegocastanibm, thanks for implementing this. I left a few comments below.

Comment on lines +384 to +388
if next_capture:
    logger.debug(
        "CUDAgraph in execution model time for %d input tokens",
        next_capture)
    self.model_runner.capture_model(next_capture)
Contributor

Also, once the list is empty, we should disable cudagraph capturing globally via set_cudagraph_capturing_enabled(False) for extra safety.
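
Something along these lines, reusing the names from the diff above (sketch only):

if next_capture:
    self.model_runner.capture_model(next_capture)
    if not self.incomplete_cudagraph_capture:
        # No pending sizes left: disable further capturing for safety.
        set_cudagraph_capturing_enabled(False)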


# Check if the scheduled token count is in our compiled CUDAgraphs list
# Priority to capture the token count that is in execution
if total_num_scheduled_tokens in self.incomplete_cudagraph_capture:
Contributor

Why not use the padded num tokens for the cudagraph? The exact num_scheduled_tokens will rarely be hit here.
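
That is, something like the following (sketch; capture_sizes here stands for the configured cudagraph capture sizes, sorted ascending):

# Round the scheduled token count up to the nearest cudagraph capture size.
padded = next(
    (size for size in capture_sizes if size >= total_num_scheduled_tokens),
    None)  # None: larger than the largest capture size, runs eagerly
if padded is not None and padded in self.incomplete_cudagraph_capture:
    next_capture = padded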

@mergify

mergify bot commented Aug 26, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @diegocastanibm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 26, 2025
@diegocastanibm diegocastanibm changed the title [Core] Lazy/Delayed CUDA graph Aug 26, 2025
@diegocastanibm
Author

Based on the feedback from reviewers, we're putting this PR on hold. In our view, the work here represents valuable progress towards a faster and more efficient vLLM initialization, but we need to address the concerns raised before moving forward.
We may revisit and continue this work in the future once we have a clearer path forward. In the meantime, I'll keep this branch available for reference.
Thanks to everyone who provided feedback during the review process.

@bbartels
Contributor

bbartels commented Sep 1, 2025

but we need to address the concerns raised before moving forward.

@diegocastanibm what are the biggest concerns in your mind? It's unfortunate that you currently have to make a trade-off between performance and startup latency.

@diegocastanibm
Author

but we need to address the concerns raised before moving forward.

@diegocastanibm what are the biggest concerns in your mind? It's unfortunate that you currently have to make a trade-off between performance and startup latency.

The biggest concerns on my mind are exactly what you've highlighted - the current trade-off between performance and startup latency is a significant limitation that affects real-world deployment scenarios.
We're actively working on developing an auto-scaler as part of our LLM-d effort that would be ready to properly A/B test the advantages of the lazy CUDAGraph approach. This would allow us to quantify the actual benefits in production workloads rather than relying on theoretical improvements.
Regarding the vLLM maintainer's concern about complexity and maintenance burden - while I understand their perspective, I believe the performance gains could justify the additional complexity, especially for production deployments where startup latency is critical. However, we'd need concrete data from our A/B testing to make a compelling case.

@hmellor
Member

hmellor commented Sep 2, 2025

This might be a naive question, but are the CUDA graphs cached anywhere? Could you generate all the CUDA graphs for all the models/batch sizes you intend to use in your deployment and then mount them to your auto scaled machine's storage for vLLM to pick up?

@diegocastanibm
Author

This might be a naive question, but are the CUDA graphs cached anywhere? Could you generate all the CUDA graphs for all the models/batch sizes you intend to use in your deployment and then mount them to your auto scaled machine's storage for vLLM to pick up?

That's actually a good question! AFAIU (and please, correct me if I'm wrong), CUDA graphs can't really be pre-generated and cached in the way you're suggesting for a few key reasons:

  1. CUDA graphs are part of NVIDIA's proprietary runtime - they're tightly coupled to the CUDA driver/runtime system and aren't something that can be easily serialized or transferred between deployments.
  2. Hardware dependency - CUDA graphs are highly specific to the actual GPU hardware they're created on. They capture low-level execution details that are tied to the specific accelerator architecture, memory layout, and even the individual GPU instance. Moving them between different GPUs, even of the same model, typically won't work reliably.
  3. Limited pickling success - Some folks have experimented with trying to pickle/serialize CUDA graphs in the past, but it's been pretty hit-or-miss. The graphs contain a lot of internal CUDA runtime state that doesn't survive the serialization process well.

The current approach where vLLM builds the CUDA graphs on-demand during warmup is really the most reliable way to handle this. The graphs need to be created in the actual runtime environment where they'll be executed to ensure compatibility and optimal performance.

@15050188022

I encountered an issue when using Lazy CUDA Graph: during multi-TP inference, the first request (i.e., the one that triggers inference) returns garbled output. Have you encountered this issue before? @diegocastanibm

@diegocastanibm
Author

I encountered an issue when using Lazy CUDA Graph: during multi-TP inference, the first request (i.e., the one that triggers inference) returns garbled output. Have you encountered this issue before? @diegocastanibm

This PR is on hold and I haven't run it again since Sept 2nd. So many things have changed in vLLM since then. However, when I was working on this PR, I never had this issue, and I tested it multiple times with different models and accelerators.
