NTTドコモR&Dの技術ブログです。

📖 vLLMのコードを読んでみよう

こんにちは、NTTドコモR&D戦略部の門間です。

この記事では、vLLMのコードを追いつつその中身の動きに迫りたいと思います。

最近、業務やプライベートでLLM関連のいろいろを触っていますが、 OSSのコードリーディングを通じてLLMの推論処理への理解を深めたいというモチベーションです。

🤖 vLLMって?

ChatGPTから始まる昨今の大規模言語モデル (以下LLM) の流行により、 自分の手元でLLMを動かす環境が整備されています。

最も基本的なものはHugging FaceのTransformerでしょう。 その他にも、簡単に使えるllama.cpp、 推論の高速化を狙ったNVIDIAのTensorRT-LLMや、SGLangなどがあります。 vLLMもそのひとつで、いわゆるLLMの「推論エンジン」です。

vLLMはLLMの推論処理を高速化する仕組みを提供しており、 特にスループット (単位時間あたりのトークン生成速度) の向上に注力しています。 速度比較ではより早いとされるものもありますが、 vLLMは情報が多く理解しやすいため、この記事ではvLLMをメインに追っていきます。

高速化の具体的な仕組みは公式の説明や様々な記事で紹介されていますが、 具体的な実装を追った記事は少ないように感じます。

この記事では、次の点を理解するのをとりあえずの目標として読み進めます。

  • オンライン推論とオフライン推論に性能差がないこと
  • スケジューリング処理
  • バッチ推論の実態とメモリブロックの管理

なお、vLLMは開発による変更が激しく、 最新のブランチを追っていると見ていた箇所が次の週には変更されていることがあります。 参照しているコードは10月~12月上旬で確認していますので、 現時点での最新コードと異なることがあります。


📚 前提知識

ゼロからLLMの仕組みを説明すると内容が長くなるため、ある程度の知識を前提に進めます。

次の記事がイメージを掴むために良くまとまっているので、ぜひ読んでみてください。

Attention Is All You Need

言わずと知れた論文です (arXivリンク)。 LLMの推論処理を追うためにはある程度の理解が必要です。

大まかには、前の推論結果を入力として次の推論結果が得られると考えれば良いでしょう。

Paged Attention

vLLMの特徴の1つであるPaged Attentionは、 OSのメモリ管理のようにページング形式のVRAM管理をする仕組みです。 これにより、限られたVRAMの効率的な利用と速度の両立を図っています。

関連してKVキャッシュという言葉も把握しておいてください。

cf. PytorchによるLLMの高速化

Continuous Batching

vLLMはバッチ処理により高速で効率的な処理をしている、と説明されることがよくあります。 ただ、単にバッチ処理といっても、どの箇所で何がバッチ化されているのかは重要です。

推論そのものがバッチ化されているのか、リクエストをバッチ受付しているのか。 vLLMでは両方行っており、バッチ推論にContinuous Batchingという戦略を採用しています。

これはLLMの推論処理を行った結果、次のトークンが得られるたびに、 その時点で追加バッチ処理可能なリクエストを確認し、 可能な限り多くのリクエストを同時に処理するというものです。

そのため、十分に効果が発揮されるのは単一のリクエストではなく、 複数のリクエストを同時に処理する場合です。


📦 vLLMの開発用インストール (Pythonコード開発のみ)

vLLMを動かしながらコードを読むために、最低限の開発用インストールを行います。

動かすだけならpip install vllmで問題ありませんが、コード改変を考慮して公式の Build from source Python-only build (without compilation) をベースに作業します。

なお、GPUはNVIDIA GeForce RTX 4070 Tiを利用しました。 実行環境依存があるかも知れませんので、その点はご留意ください。

Wheelのインストール

まずは、vLLMのC/C++/CUDA部分を除いた部分を動かすためのWheelを取得します。 uvを使ってpythonを管理しているので次のコマンドで実行しました。

uv pip install https://vllm-wheels.s3.us-west-2.amazonaws.com/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl

が、いきなりここでハマりました。

現時点 (2024/11/18) だと公式には次のように「URLを一意にするため」にバージョン名を固定すると書いてあります。

Note that the wheels are built with Python 3.8 ABI (see PEP 425 for more details about ABI), so they are compatible with Python 3.8 and later. The version string in the wheel file name (1.0.0.dev) is just a placeholder to have a unified URL for the wheels. The actual versions of wheels are contained in the wheel metadata. Although we don’t support Python 3.8 any more (because PyTorch 2.5 dropped support for Python 3.8), the wheels are still built with Python 3.8 ABI to keep the same wheel name as before.

しかし、これだとパッケージの依存性の解決が正常に行われず、真面目に依存解決をしてくれるパッケージマネージャを使うと 次のようなエラーが出てしまいます。生のpipで行うと成功しますが、これはpipの実質的なバグとして報告されています。

error: Failed to install: vllm-1.0.0.dev0-cp38-abi3-manylinux1_x86_64.whl (vllm==1.0.0.dev0 (from https://vllm-wheels.s3.us-west-2.amazonaws.com/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl))
  Caused by: Wheel version does not match filename: 0.6.3.dev155+gf3a507f1.d20241010 != 1.0.0.dev0

そのため、公式チュートリアルに従うには通常のpip installを使用する必要があります。

uv venvで作成したvenvにpipがない場合、システムpythonにインストールするか、venvにpipを追加して(uv add pipなど) pip installを実行するなどの何らかの対処が必要です。今回はvenvにpipを追加して実行しました。

リポジトリのクローン

公式に従い、次のコマンドを実行します。

git clone https://github.com/vllm-project/vllm.git
cd vllm
python python_only_dev.py

実行すると、仮想環境下に展開されたパッケージのバイナリをcloneしたvllmディレクトリにコピーし、 cloneされたディレクトリを仮想環境にシンボリックリンクします。

また、スクリプトを-q付きで利用するとコピーされたものをもとに戻せます。

起動確認

これで開発環境が整いました。

では起動してみましょう。 起動コマンドは現時点 (2024/11/18) でQuickstartに記載のある以下を利用します。

vllm serve Qwen/Qwen2.5-1.5B-Instruct

起動しました。

INFO 11-18 20:44:07 api_server.py:592] vLLM API server version 0.6.4.post2.dev25+g01aae1cc
INFO 11-18 20:44:07 api_server.py:593] args: Namespace(subparser='serve', model_tag='Qwen/Qwen2.5-1.5B-Instruct', 
...(中略)...
INFO:     Started server process [169392]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)

OpenAI互換のChat Completion APIも問題なく動いています。

$ curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen2.5-1.5B-Instruct",
"prompt": "San Francisco is a",
"max_tokens": 7,
"temperature": 0
}'

{"id":"cmpl-f80bb25d622b4af5bc41c5e5b425e16a","object":"text_completion","created":1731931161,"model":"Qwen/Qwen2.5-1.5B-Instruct","choices":[{"index":0,"text":" city in the state of California,","logprobs":null,"finish_reason":"length","stop_reason":null,"prompt_logprobs":null}],"usage":{"prompt_tokens":4,"total_tokens":11,"completion_tokens":7,"prompt_tokens_details":null}}

Pythonコードの改変

最後にリポジトリのコードを修正することで実際コードの改変が適用されるのかを確認します。

vllmコマンドが何をしているのかをみてみると、 次のコードの通りfrom vllm.scripts import mainでimportされるmain()を実行しているに過ぎません。

# -*- coding: utf-8 -*-
import re
import sys
from vllm.scripts import main
if __name__ == '__main__':
    sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
    sys.exit(main())

変更の確認のため、main()の中にprint()を追加してみます。

diff --git a/vllm/scripts.py b/vllm/scripts.py
index a51c21cf..eeffebb3 100644
--- a/vllm/scripts.py
+++ b/vllm/scripts.py
@@ -141,6 +141,7 @@ def env_setup():


 def main():
+    print("Modified !!!")
     env_setup()

     parser = FlexibleArgumentParser(description="vLLM CLI")

確かに表示されました。問題なく動作していますね。

$ vllm serve Qwen/Qwen2.5-1.5B-Instruct

Modified !!!
INFO 11-18 21:09:08 api_server.py:592] vLLM API server version 0.6.4.post2.dev25+g01aae1cc
...()...

このままコード修正を続ければ、vLLMへの改変が容易に可能です。

デバッガを使ったOSSのコードリーディングのススメ

世の皆様はOSSのコードリーディングをどのように行っているでしょうか。

もちろん、具体的に追いたい処理があれば、関連しそうなワードでgrepして周辺をあたるのが早いでしょう。 それほど複雑でなければ、エントリーポイントを探してそこから辿っていくのも有効です。

ただ、個人的にはコードを読むときはある程度実際に動かしてみるのが理解しやすく感じます。 幸いなことにPythonはVSCodeで簡単にデバッグ実行できるので、それを利用して今回は読み進めました。 (ライブラリの中までデバッグ実行するだけなら、justMyCodefalseに設定するだけでも可能ですが、 今回は今後の修正も見越して実際にコードが改変できる状態にしています)

ということで、デバッガを使って要所要所でブレークしながら進めるというスタイルが、 ランタイムに実際に変数に格納される中身も見ることができ理解を進めやすいと思っていますのでおすすめしておきます。


🧩 vLLMのソフトウェアアーキテクチャ

では、本題となるvLLMのコードを追ってみましょう。

意外と公式のドキュメントが充実しているので、まずはそちらを参照してみることをおすすめします。 ドキュメントにはstable版とlatest版がありますので、自分が実行する、 または解析するコードに応じて参照先を決定するとよいでしょう。

特にArchitecture Overviewには、 vLLMを利用する際のエントリーポイントから、中核となるLLMEngineクラスの動作、 及びその内部の各種クラス階層までわかりやすく記載されています。

改めて目的は、次の3つの理解です。 まずは1つ目、アーキテクチャをコードを追いながら理解しつつ、オンライン推論とオフライン推論に性能差がないことを確認していきます。

  • オンライン推論とオフライン推論に性能差がないこと 👈
  • スケジューリング処理
  • バッチ推論の実態とメモリブロックの管理

なお、オンライン推論とはOpenAI互換APIにリクエストして推論を返す形式を指しており、 オフライン推論とはvllm.LLMクラスまたはLLMEngineクラスを直接利用して推論する形式を指します。

オンライン推論 : FastAPIサーバの立ち上げとEngineClientの生成

はじめに大まかな構造を図示します。

vLLMでサーバを立ち上げると、デフォルト設定では次の図のような構成で処理されることになります。

vLLMのOpenAI互換サーバ

内部ではエンジンクライアント (EngineClient) が推論エンジン (LLMEngine) と通信する構成になっています。 (図では、エンジンクライアントは以降に触れる具体的なインスタンスとしてMQLLMEngineClientと示されています)

EngineClientは非同期にAPIリクエストを受け付け、適宜LLMEngineと通信することで推論します。

次に、エントリーポイントからの処理の流れを示します。

---
config:
  theme: neutral
---
graph TD
A["vllm.entrypoints.openai.api_server.run_server()"]

subgraph EngineClientの生成
B["vllm.entrypoints.openai.api_server.build_async_engine_client()"]
C["vllm.entrypoints.openai.api_server.build_async_engine_client_from_engine_args()"]
B --> C
end

subgraph FastAPIアプリの構築
P["vllm.entrypoints.openai.api_server.build_app()"]
end

subgraph FastAPIの状態の初期化
N["vllm.entrypoints.openai.api_server.init_app_state()"]
end

subgraph サーバの起動
M["vllm.entrypoints.launcher.serve_http()"]
end

A --> EngineClientの生成
EngineClientの生成 --> FastAPIアプリの構築
FastAPIアプリの構築 --> FastAPIの状態の初期化
FastAPIの状態の初期化 --> サーバの起動

style EngineClientの生成 rx:10,ry:10
style FastAPIアプリの構築 rx:10,ry:10
style FastAPIの状態の初期化 rx:10,ry:10
style サーバの起動 rx:10,ry:10

オンライン推論処理では、API互換のサーバを立ち上げるためにvllm.entrypoints.openai.api_server.run_server()をエントリーポイントとしています。

まずはこちらからみていきましょう。 この関数はOpenAI API互換のサーバを立ち上げ、uvicornを利用してサーバを起動します。

処理の流れは次の通りです。

  1. EngineClientの生成
  2. FastAPIアプリケーションの構築
  3. FastAPIの状態の初期化
  4. サーバの起動

1. EngineClientの生成

vllm.entrypoints.openai.api_server.build_async_engine_client()EngineClientが生成されます。 このクライアントは、推論エンジンとの通信するためのクラスであり、推論エンジンの生成と通信の設定をします。

設定に基づいて具体的なインスタンスとしてAsyncLLMEngineまたはMQLLMEngineClientを生成します。 現時点 (2024/11/25時点) では、デフォルト設定だと可能な限りMQLLMEngineClientが生成されます。

どちらのエンジンクライアントでも同じIFをサポートしていますので、以降はMQLLMEngineClientを前提として記述します。 詳細は後述します。

その後、マルチプロセスで推論エンジンを立ち上げ、MQLLMEngineClientのインスタンスの生成とセットアップを行います。

2. FastAPIアプリケーションの構築

vllm.entrypoints.openai.api_server.build_app()でFastAPIアプリケーションが構築されます。 このアプリケーションは、OpenAI API互換のエンドポイントを提供します。

FastAPIのAPIRoutervllm/entrypoints/openai/api_server.pyでグローバルに定義されており、 推論に直接関わる部分として次の3つのエンドポイントが定義されています。

@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                raw_request: Request):
  ...

@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
  ...

@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
  ...

現在一般的なのはおそらくChat Completion APIであるため、以降はcreate_chat_completion()を見ていきます。

3. FastAPIの状態の初期化

vllm.entrypoints.openai.api_server.init_app_state()でFastAPIアプリケーションの状態を初期化します。 ここで、FastAPIのstateに生成した推論エンジンを紐づけます。

4. サーバの起動

vllm.entrypoints.launcher.serve_http()でFastAPIアプリケーションをHTTPサーバとして立ち上げ、 vLLMのオンライン推論を開始できます。

オンライン推論 : OpenAI互換サーバでのリクエスト受付

OpenAI互換サーバでのリクエスト受付は、 vllm.entrypoints.openai.api_server.create_chat_completion()で行われます。

処理の流れは次の通りです。

  1. リクエストの前処理
  2. エンジンへのリクエスト送信

1. リクエストの前処理

self._preprocess_chat()でリクエストの前処理を行います。 細かな処理が色々と行われますが、処理のポイントはLLMに入力するためのトークン列の生成です。

メソッド内部ではconversationrequest_promptsengine_promptsの3つの値を取得します。

conversationはChat Completion APIでよく指定される次の形式の辞書のリストです。

conversation = [
  {"role": "system", "content": "You are a helpful assistant."},
  {"role": "user", "content": "Who won the world series in 2020?"}
]

request_promptsはモデルで指定されたchat_templateフォーマットを適用したものです。 代表的なところでは次のような形式に変換されることが多いでしょう。

request_prompts = '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWho won the world series in 2020?<|im_end|>\n<|im_start|>assistant\n'

最後に、engine_promptsは推論への直接の入力となるトークナイズされた文字列です。 上記のrequest_promptsQwen/Qwen2.5-1.5B-Instructへの入力とした場合、例えば次のような値となります。

engine_prompts = [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 15191, 2765, 279, 1879, 4013, 304, 220, 17, 15, 17, 15, 30, 151645, 198, 151644, 77091, 198]

2. エンジンへのリクエスト送信

self.engine_client.generate()でエンジンクライアントを介して、推論エンジンに対して推論リクエストを送信します。

内部では、まずリクエストに対応するasyncioのキューを作成し、 リクエストIDをキーとしてリクエストキューをクライアントに登録します。

次に、生成指示をエンジンに送信します。今回はEngineClientMQLLMEngineClientとする前提であるため、RPCProcessRequestにくるまれて、処理プロセスへ送信されます。

最後に、送信されたリクエストはMQLLMEnginerun_engine_loop()で受け取られ、推論エンジンの処理が実施されます。

オンライン推論 : MQLLMEngineの生成とLLMEngineの起動

次に、クライアントとの通信を仲介して推論エンジン本体を制御するラッパーであるMQLLMEngineの生成と、 推論エンジンの起動の流れを追ってみましょう。

まずはざっくりコールスタックです。

---
config:
  theme: neutral
---
graph TD
S(["EngineClientの生成"])
S --> A["vllm.engine.multiprocessing.engine.run_mp_engine()"]
A --> B["vllm.engine.multiprocessing.engine.MQLLMEngine.from_engine_args()"]

subgraph LLMEngineの初期化
D["MQLLMEngine.init()"]
end

subgraph MQLLMEngineの起動
E["engine.start()"]
E --> F["self.run_startup_loop()"]
E --> G["self.run_engine_loop()"]
end

subgraph エンジンループ
    subgraph リクエストの処理
    H["self.handle_new_input()"]
    H --> I["self.handle_process_request()"]
    I --> J["self.engine.add_request()"]
    end

    subgraph 推論の実行
    K["self.engine_step()"]
    K --> L["self.engine.step()"]
    end

    リクエストの処理 --> 推論の実行
end

B --> LLMEngineの初期化
LLMEngineの初期化 --> MQLLMEngineの起動
G --> エンジンループ

style LLMEngineの初期化 rx:10,ry:10
style MQLLMEngineの起動 rx:10,ry:10
style エンジンループ rx:10,ry:10
style リクエストの処理 rx:10,ry:10
style 推論の実行 rx:10,ry:10

処理概要は次の通りです。

  1. MQLLMEngineの初期化
  2. MQLLMEngineの起動
  3. リクエストの処理

1. MQLLMEngineの初期化

vllm.engine.multiprocessing.engine.run_mp_engine()MQLLMEngineを生成し、その後エンジンを実行します。

ややこしいのですが、ここで作成されるのはMQLLMEngineでありMQLLMEngineClientではありません。 MQLLMEngineの処理の実態は、内部で生成し保持されるLLMEngineインスタンスです。

vllm.engine.multiprocessing.engine.MQLLMEngine.from_engine_args()MQLLMEngineが生成され、 その内部でLLMEngineが初期化されます。

このLLMEngineが推論処理の本体であり、vLLMの処理のキモです。

2. MQLLMEngineの起動

engine.start()MQLLMEngineが起動します。

具体的には内部で2つのループを実行します。

1つ目がself.run_startup_loop()で、EngineEngineClientへのIPC通信を確立します。 2つ目がself.run_engine_loop()で、Engine処理のメインループを実行します。

これが処理の中心となるループで、エンジンの状態を監視し、 クライアントからのリクエストを受け付けます。

もう少し詳しくself.run_engine_loop()の処理を見ていきましょう。

まず基本となるループ処理として、エンジンに処理中のリクエストがない場合はクライアントからのリクエストを待ち受けます。

def run_engine_loop(self):
    """Core busy loop of the LLMEngine."""

    while True:
        if not self.engine.has_unfinished_requests():
            # Poll until there is work to do.
            while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
                # When there's no work, check on engine health and send
                # health status back to client
                self._health_check()
                self.engine.do_log_stats()
                logger.debug("Waiting for new requests in engine loop.")

        # Handle any input from the client.
        self.handle_new_input()

        # Engine step.
        request_outputs = self.engine_step()

        # Send request outputs (if async, done in engine_step callback).
        if not self.use_async_sockets:
            self._send_outputs(request_outputs)

新たなリクエストが追加された場合には、次の呼び出しで新規リクエストをキューイングします。

# Handle any input from the client.
self.handle_new_input()

内部のLLMEngineの推論をstep()で1つ進めることで、キューに入ったリクエストの次のトークンを生成します。 step()engine_step()の内部で呼び出されます。

# Engine step.
request_outputs = self.engine_step()

# Send request outputs (if async, done in engine_step callback).
if not self.use_async_sockets:
    self._send_outputs(request_outputs)

3. リクエストの処理

self.handle_new_input()でクライアントからの新しいリクエストを処理します。

内部では、self._handle_process_request()でエラーハンドリングを行い、 self.engine.add_request()を呼び出してリクエストをエンジンのリクエストキューに追加します。

その後、self.engine_step()を呼び出してエンジンの推論処理を1ステップ進めます。

LLMEngine

ではいよいよvLLMの処理の本体であるLLMEngineを見ていきましょう。

まずオーバービューとして、LLMEngineは内部的にモデルのロードやメモリの初期化を行い、 step()メソッドで推論処理を行う仕様になっています。

なお、前述のオンライン推論処理では、最終的にLLMEngineに対して handle_new_input()でのリクエストの追加と、engine_step()での推論処理の実行が行われていました。

LLMEngineの初期化

LLMEngineの初期化
LLMEngineの初期化

初期化処理は図のように行われ、LLMEngineの内部には幾つかのクラスが生成されます。 ざっくりコールスタックも次に示しておきましょう。

---
config:
  theme: neutral
---
graph TD
A["LLMEngine.init()"] 

subgraph Executorの初期化
C["executor_class.init()"]
C --> D["self.model_executor.init_executor()"]
D --> モデルロード

    subgraph モデルロード
    E["self.create_worker()"]
    E --> I["self.driver_worker.load_model()"]
    I --> J["self.model_runner.load_model()"]
    end
end

subgraph KVキャッシュの初期化
O["self.initialize_kv_caches()"]

    subgraph 利用可能メモリの決定
    P["self.model_executor.determine_num_available_blocks()"]
    P --> Q["self.driver_worker.determine_num_available_blocks()"]
    end

    subgraph VRAMの確保
    direction TB
    R["self.model_executor.initialize_cache()"]

        subgraph CUDAGraphのキャプチャ
        U["self.warm_up_model()"]
        U --> V["self.model_runner.capture_model()"]
        end

    R --> S["self.driver_worker.initialize_cache()"]
    S --> T["self.init_cache_engine()"]
    T --> CUDAGraphのキャプチャ
    end

    O --> 利用可能メモリの決定
    利用可能メモリの決定 --> VRAMの確保
end

subgraph スケジューラの初期化
W["vllm.core.scheduler.Scheduler.init()"]
W --> X["self.scheduler.init_scheduler()"]
end

A --> Executorの初期化
Executorの初期化 --> KVキャッシュの初期化
KVキャッシュの初期化 --> スケジューラの初期化

style Executorの初期化 rx:10,ry:10
style モデルロード rx:10,ry:10
style KVキャッシュの初期化 rx:10,ry:10
style 利用可能メモリの決定 rx:10,ry:10
style VRAMの確保 rx:10,ry:10
style CUDAGraphのキャプチャ rx:10,ry:10
style スケジューラの初期化 rx:10,ry:10

処理の流れは次の通りです。

  1. Executorの初期化
  2. KVキャッシュの初期化
  3. スケジューラの初期化
1. Executorの初期化

Executorは実行状態を管理するクラスで、内部でWorkerModelRunnerのクラスが生成され、 それぞれの役割に応じた処理を行います。

executor_class.__init__()Executorの初期化を行います。 executor_classはデフォルトでかつGPUが利用可能な場合はGPUExecutorに解決されます。

Executorは初期化中にWorkerを生成します。 GPUExecutorの場合、WorkerひとつにつきGPUひとつを管理するモデルとなっています。

呼び出しの階層が非常に深いので詳細は省略しますが、最終的にはモデルのロード処理が行われます。 デフォルトではtorchによる読み込み処理が行われます。

2. KVキャッシュの初期化

LLMEngine._initialize_kv_caches()でKVキャッシュの初期化を行います。 ここでvLLMの工夫点のひとつであるPaged Attentionを利用するため、メモリをブロックに分割して管理しています。

具体的な割り当てサイズの計算のため、determine_num_available_blocks()でOOMにならないギリギリのメモリ割り当て数を計算します。 実施時点でのメモリ使用量とダミーデータを使った推論の実行後のメモリ使用量を計算し、 搭載されたVRAMに利用率を乗じた値から、ピーク利用メモリを差し引いた値を利用可能メモリとします。 利用率はデフォルトでは0.9 (空きの90%をKVキャッシュとして利用) とされています。

GPUのVRAM上のモデルとキャッシュ
GPUのVRAM上のモデルとキャッシュ

なお、同時に推論できるリクエスト数はここで確保されるKVキャッシュの大きさに依存します。 ただ、当然LLMへのリクエストによってプロンプトのサイズは異なりますし、生成される系列の長さも一定ではありません。 そこで、ブロックに分けたメモリ領域を動的に割り当てることでキャッシュを効率的に利用し、 同時処理数を最大化しつつ高速な推論を可能とするのがvLLMということですね。

キャッシュサイズは、CacheEngine.get_cache_block_size()で取得されます。

key_cache_block = block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_attention_layers * (key_cache_block + value_cache_block)

return dtype_size * total

その値は、具体的には上記の計算式の結果であるtotalにデータ型のサイズを掛けた値となります。 なお、block_sizeはデフォルトでは16が指定されており、 Size of a cache block in number of tokens.とのことから、 デフォルトでは1ブロックあたり16トークンをキャッシュします。

具体的な挙動は記事の最後で確認しています。

最後にvllm.worker.model_runner.ModelRunner._warm_up_model()でモデルのウォームアップを行います。 ウォームアップとは、CUDA Graphsという機能を有効化することを指しています。 CUDA GraphsはCUDAの計算を事前に実行してキャプチャすることで、 再度呼び出し可能な最適化された処理を行うことができる機能のようです。

cf. https://www.mattari-benkyo-note.com/2021/10/23/pytorch-cuda-graphs/

3. スケジューラの初期化

vllm.core.scheduler.Scheduler.__init__()でスケジューラの初期化を行います。

スケジューラは、推論処理のスケジューリングを行うクラスであり、モデルの推論処理を管理します。

分散推論を行う場合、スケジューラはその並列数だけ生成されますが、 今回は簡単のため生成されるのは1つだけとします。

推論の要求をオンラインで受け付ける場合、実行時にならないと実行される処理の量がわかりません。 スケジューラはこれを解決するために、どのリクエストを処理中にするかを決定します。

リクエストには実行中 (runnign)、待機中 (waiting)、スワップ中 (swapped) の状態があり、 それぞれのキューで管理されています。

実行中は次の推論で実行されるものであり、待機中はまだ推論が実行されていないものです。

スワップ中は主にVRAMのサイズが足りないため、一時的にメモリに退避されているものです。 最大限並列する推論を試みますが、生成系列が長い場合は当初想定していたVRAM不足になることがあります。 そういったときに、一時的にDRAMサイドにメモリを退避することでVRAMを解放し他のリクエストを処理できます。

リクエストのスワップ
リクエストのスワップ

詳細は後述します。

LLMEngineへのリクエストの追加

推論リクエストの追加
推論リクエストの追加

LLMEngine.add_request()でエンジンのリクエストキューにリクエストを追加します。 この時点では待機中のリクエストに追加されるだけで、処理はまだ行われません。

様々な形式でのリクエストが受け付けられますが、 このタイミングで前処理をかけ、最終的にトークナイザーでトークン化された数値のリストを エンジンのリクエストとなるように整形します。

LLMEngineの推論処理

ようやく推論処理の本体であるLLMEngine.step()までたどり着きました。

これを実行することで、エンジンの推論処理を1ステップ進める、つまり次の1トークンを生成します。 たった1トークン得るところまででも、処理の最初から追っていくのは長い道のりですね。

ざっくりコールスタックは次のようになります。

---
config:
  theme: neutral
---
graph TD
A["LLMEngine.step()"]

    subgraph スケジューリング
    B["self.scheduler[virtual_engine].schedule()"]
    B --> C["self.schedule()"]
    end

    subgraph モデルの実行
    E["self.model_executor.execute_model()"]
    E --> F["self.driver_worker.execute_model(execute_model_req)"]
    F --> H["self.model_runner.execute_model()"]
    H --> I["model_executable()"]
    end

A --> スケジューリング
スケジューリング --> モデルの実行

style スケジューリング rx:10,ry:10
style モデルの実行 rx:10,ry:10

処理の流れは次の通りです。

  1. スケジューリング
  2. モデルの実行
1. スケジューリング
LLMEngine.step() による推論の実施
LLMEngine.step() による推論の実施

LLMEngine.step()の内部では、まずスケジューラのschedule()メソッドが呼び出されます。

具体的なスケジューリング処理とタイムラインの例は別途後述しますが、 大雑把には新規リクエストがあればそれを優先して処理し、なくなり次第可能な限り多くのリクエストを同時に処理できるようスケジュールします。

基本的には、待機中のリクエストは順次実行中に移行され、メモリが許す限りは全て同時に処理されますが、 メモリが逼迫している場合は一時的にスワップ状態にされます。

2. モデルの実行
モデルの実行
モデルの実行

モデルの実行は、LLMEngine.model_executor.execute_model()で行われます。 上の図のように、最終的にはtorchのモジュールの実行になります。

ただし、高速化のために内部処理にはCUDA Graphsを利用しています。

オフライン推論 : LLMEngineの直接実行

オンライン推論処理では、LLMEngineのインスタンスを生成し、 それを介してリクエストの追加と推論処理の実行が行われていたことを確認しました。

オフライン推論処理ではどうでしょうか?

まずは公式のLLM Engine exampleの処理を見てみます。

cf. https://docs.vllm.ai/en/latest/getting_started/examples/llm_engine_example.html

結論から言うと公式のLLM Engine exampleの基本処理は、 MQLLMEnginerun_engine_loop()処理と同様です。

次のprocess_request()と前述のrun_engine_loop()とを比較するとほとんど同じです。 このExampleがLLMEngineの動かし方の例なので当然と言えば当然ですね。

def process_requests(engine: LLMEngine,
                    test_prompts: List[Tuple[str, SamplingParams]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0

    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            prompt, sampling_params = test_prompts.pop(0)
            engine.add_request(str(request_id), prompt, sampling_params)
            request_id += 1

        request_outputs: List[RequestOutput] = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                print(request_output)

オフライン推論 : vllm.LLMクラスの利用

オフライン推論がオンライン推論と同じ結果を得られるかどうか、本題はvllm.LLMクラスの利用をする場合です。 vLLMのオフライン推論処理を実施する際は、次のようなサンプルコードが紹介されがちです。

このコードが結果としてオンライン推論時と同じ処理になるのであれば、 概ね提供方法の差でしか無いということができます。

cf. https://docs.vllm.ai/en/v0.6.1.post2/getting_started/examples/offline_inference.html

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

このコードはLLMクラスを介してvLLMの推論機能を利用していますが、内部的にはLLMEngineを利用して推論処理が行われています。

また、プロンプトが1つであろうとリストで与えられようと、 それぞれのプロンプトはLLMEngine.add_request()を介してエンジンにリクエストされます。

class LLM:
    ...

  def _add_request(
      self,
      prompt: PromptType,
      params: Union[SamplingParams, PoolingParams],
      lora_request: Optional[LoRARequest] = None,
      prompt_adapter_request: Optional[PromptAdapterRequest] = None,
      priority: int = 0,
  ) -> None:
      request_id = str(next(self.request_counter))
      self.llm_engine.add_request(
          request_id,
          prompt,
          params,
          lora_request=lora_request,
          prompt_adapter_request=prompt_adapter_request,
          priority=priority,
      )

したがって、オフライン推論・オンライン推論ともに、 最終的にはLLMEngine.add_request()が呼び出されます。

LLMクラスを利用するかAPIを経由してリクエストを送信するかは、 LLMEngineへのリクエストの送信方法の違いだけであり、 コアとなるエンジン自体の挙動は変わりません。

長かったですね。ようやく当初の目的までたどり着きました。 vLLMのオンライン推論とオフライン推論に性能の違いは無いとわかりました。 でももう少しだけ続きます。

vLLMのバッチリクエスト処理

vLLMにはOpenAIのBatch API相当のjsonlファイルを受け取って推論をするrun_batch.pyが用意されています。 これも結局はLLMEngine(厳密にはAsyncLLMEngine)を利用して推論処理を行っています。

最終的にはvllm.entrypoints.openai.serving_chat.OpenAIServingChatcreate_chat_completion()を、 jsonlに含まれる有効なリクエスト数分だけ呼び出しているにすぎません。

なお、若干ややこしいところでAsyncLLMEngine自体はEngineClientの一種であり、 エンジンそのものは、LLMEngineを継承する_AsyncLLMEngineです。

class AsyncLLMEngine(EngineClient):
"""An asynchronous wrapper for :class:`LLMEngine`."""
  ...

  _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods."""
  ...

AsyncLLMEngineは、内部に_AsyncLLMEngineのインスタンスを保持しています。 MQLLMEngineClientからLLMEngine.add_request()が呼ばれるのと同様に、 _AsyncLLMEngine.add_request_async()AsyncLLMEngineから呼ばれることで処理がキューイングされます。

create_chat_completion()は抽象化されたEngineClient.generate()を呼び出すだけですので、 このrun_batchにおいても、コアとなるエンジンの挙動は変わりません。


🕒 リクエストのスケジューリング

さて、次と言うには1つ目が長くて重すぎるところはありましたが、 ようやく目的の2つ目に移ります。

  • オンライン推論とオフライン推論に性能差がないこと
  • スケジューリング処理 👈
  • バッチ推論の実態とメモリブロックの管理

vLLMのスケジューリング処理ですね。

vLLMは効率的なバッチ推論のために、今回どのリクエストを処理すべきかを 毎回の推論ステップごとにスケジューリングを行っています。

まず、vLLMの処理ステージは次の2つの状態で管理されます。

class SequenceStage(enum.Enum):
    PREFILL = enum.auto()
    DECODE = enum.auto()

PREFILLは入力プロンプトの評価であり、GPUの処理負荷が高い処理です。 DECODEは次の生成トークンの評価であり、GPUのメモリ負荷が高い処理です。

デフォルトでは、可能な限り多くのリクエストを同時に処理するため、 PREFILL対象のリクエストをスケジュールすることを優先します。

具体的なデフォルトのスケジューリングロジックは以下にざっくり記載しますが、 LLMEngineにリクエストが追加されるたびに、各リクエストがどのように処理されるのかを次の図で示します。

推論処理のタイムライン
推論処理のタイムライン

同時に処理できるリクエストはmax_num_seqsになり、次に説明しているバッチ推論の実態とも関わってきます。

具体的なスケジューリングロジック

まず、スワップ中のシーケンスグループ (≒リクエスト) が存在しない場合は、 PREFILL対象にできる処理を待機中キューから次の実行対象にスケジュールすることを検討します。

無条件に追加するわけではなく、前回受付時間よりも一定時間経っている (≒latencyの待ち上限に引っかからない※) かつ、 待機中 (SequenceStatus = WAITING) のシーケンスグループがある限り、今回のステップで処理可能な対象を候補として取り上げます。

※ 次の条件を満たす場合。

passed_delay = (
                (now - earliest_arrival_time) >
                (self.scheduler_config.delay_factor * self.last_prompt_latency)
                or not self.running)

追加の際はプロンプト長も考慮されます。

プロンプト長がモデルの入力サイズを上回っておらず、VRAMへの割り当てが可能である場合は、 最終的に実行中キューに追加 (SequenceStatus = RUNNINGへ変更) します。

プロンプト長を超過していたり、超過していなくてもメモリサイズ的に推論不能と 判断された対象シーケンスはFINISHED_IGNOREDでマーキングされ推論対象から外されます。

LoRA適用をする場合、同時に読み込めるLoRA数上限を超えていた場合にも、一旦待機中キューに戻され次のステップを待機します。

PREFILL対象の処理がここまででスケジュールされた場合、今回のステップではデコード対象はスケジュールされません。 そうでない場合はデコード対象の処理をスケジュールします。 そのため、デフォルトのスケジューリングロジックでは、runningキューにはPREFILLDECODEのどちらかのみが含まれます。

また、デコード対象をスケジュールする際に余力があればスワップ中の処理を再開させます。

なお、実行状態はEnumで次のように定義されており、2より大きい値のstateになったときには終了とみなされます。

class SequenceStatus(enum.IntEnum):
    """Status of a sequence."""
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Note: anything after SWAPPED (2) will be considered
    # as a finished status.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        return status > SequenceStatus.SWAPPED

Chunked Prefill

Chunked Prefill (※) が設定されている場合は多少異なるスケジュール戦略となります。 今回はこちらの動作の詳細は追うのを割愛し概要に留めます。

※ cf. https://docs.vllm.ai/en/latest/models/performance.html#chunked-prefill

Chunked Prefillは、大まかにはPREFILLよりもDECODEを優先する設定です。 冒頭に示した図の通り、デフォルトのスケジュール処理ではPREFILL対象の処理がある限りそれをさばくのを優先し、 DECODE対象の処理はPREFILLがなくなるまで次のトークンを生成しません。 これは、リクエストに対して TTFT (Time To First Token) を最小化する戦略です。

しかし、DECODE対象の処理がを待機するということは、そのリクエストにおいて、ITL (Inter Token Latency) が増加することを意味します。

Chunked Prefillはその名の通り、PREFILL処理をチャンクに分割して処理することで、DECODE対象の処理と一緒にバッチ処理する機能です。

状況にもよりますが、公式記載いわく「設定値によってスループットは低下するおそれがある」とのこと。 どの程度かは実際に試してみる必要がありそうですが、ある程度TTFTかITLのトレードオフになりそうです。


⚡ バッチ推論の実態とメモリブロックの管理

2つ目は程よい長さでしたね。 さて、次が3つ目、最後のバッチ推論の実態とメモリブロックの管理 についてです。

  • オンライン推論とオフライン推論に性能差がないこと
  • スケジューリング処理
  • バッチ推論の実態とメモリブロックの管理 👈

バッチ推論とは言うけれど具体的にその処理はどうやって起動しているのか、 工夫点としてあげられているPaged Attentionは実際に実行するとどのように見えるのか。

まず、バッチ処理の本体は、vLLMの現状の実装ではCUDA Graphsのreplayに相当します。

Currently cuda graph is only supported by the decode phase.

コードのこの記載通りで、このときmodel_executableに取得されるのはCUDA Graphです。

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        # TODO: We can remove this once all
        # virtual engines share the same kv cache.
        virtual_engine = model_input.virtual_engine
        if prefill_meta is None and decode_meta.use_cuda_graph:
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[virtual_engine][
                graph_batch_size]
        else:
            model_executable = self.model

このコードでは、graph_batch_sizeにバッチサイズを取得し、 graph_runnersに登録されたmodel_executableを取得しています。

CUDA Graphsのキャプチャ

あらかじめキャプチャされているrunnerは次のようになっており、 256バッチまでのCUDA Graphsが登録されています。

キャプチャ済みのCUDA Graphs
キャプチャ済みのCUDA Graphs

実行時にどのくらいのバッチサイズを想定してCUDA Graphsをキャプチャするかは batch_size_capture_listに計算されており、上限はself.max_batchsize_to_captureです。

        graph_batch_size = self.max_batchsize_to_capture
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

_BATCH_SIZES_TO_CAPTUREは次のコードで定義されています。

_BATCH_SIZE_ALIGNMENT = 8
# all the token sizes that **can** be captured by cudagraph.
# they can be arbitrarily large.
# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
# the actual sizes to capture will be determined by the model,
# depending on the model's max_num_seqs.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
]

このままだと最大で8192まで登録されるように見えますが、実際には注釈に記載通り、 self.max_batchsize_to_capture_get_graph_batch_size()で最大値が決定されます。

        self.max_batchsize_to_capture = _get_max_graph_batch_size(
            self.scheduler_config.max_num_seqs)

ここで、_get_max_graph_batch_size()の実装は以下であり、self.scheduler_config.max_num_seqsは、 LLMEngineの設定として与えられるmax_num_seqsに由来します。 デフォルトでは256となっているため、画像のような256までのCUDA Graphsが登録されることになります。

cf. https://docs.vllm.ai/en/stable/models/engine_args.html

def _get_max_graph_batch_size(max_num_seqs: int) -> int:
    """
    max_num_seqs: Maximum number of sequences in a batch.
    _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.

    pad the max_num_seqs if necessary by calling _get_graph_batch_size,
    which will deal with some edge cases like 1, 2, 4.

    if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size.
    if not, it means the padded size is larger than the largest size in
    _BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE.
    """
    padded_size = _get_graph_batch_size(max_num_seqs)
    if padded_size in _BATCH_SIZES_TO_CAPTURE:
        return padded_size
    assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
    return _BATCH_SIZES_TO_CAPTURE[-1]

実際のバッチ数が完全に一致しない場合

バッチ数が完全に一致しない場合は、可能なバッチサイズのうちで最も近い値を取得します。 例えばrunningなリクエストが5つある状態で、推論直前のバッチサイズをみてみると次のようになっています。

実行直前のバッチサイズ確認
実行直前のバッチサイズ確認

graph_batch_sizeが8になっているので、バッチを実行可能でかつ最も小さいサイズのCUDA Graphsが起動されるようです。 実際の処理では、既定のバッチサイズに満たない分をpaddingとして計算し、そのサイズ分だけ0で埋める処理が行われます。

これによって、必ずしもキャプチャされたCUDA Graphsのバッチサイズと一致しなくても、paddingを行うことでCUDA Graphsを利用できます。

cuda_graph_pad_size = self._get_cuda_graph_pad_size(
    num_seqs=len(seq_lens),
    max_decode_seq_len=max_decode_seq_len,
    max_encoder_seq_len=max_encoder_seq_len)

batch_size = len(input_tokens)
if cuda_graph_pad_size != -1:
    # If cuda graph can be used, pad tensors accordingly.
    # See `capture_model` API for more details.
    # vLLM uses cuda graph only for decoding requests.
    batch_size += cuda_graph_pad_size

# Tokens and positions.
if cuda_graph_pad_size:
    input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
assert self.runner.device is not None
input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
                                       self.runner.device,
                                       self.runner.pin_memory)

token_types_tensor = async_tensor_h2d(token_types, torch.long,
                                       self.runner.device,
                                       self.runner.pin_memory) \
                                        if token_types else None

Paged Attention (非連続なVRAMの利用)

前述のように、vLLMの実装上の大きな工夫点はOSのメモリ管理のようにページング形式のVRAM管理をする点です。 これにより、あらかじめ確保したKVキャッシュ領域に対してオンデマンドで必要な分だけメモリを割り当てるような動作となります。

メモリはデフォルトでは16トークンを1つのブロックとして扱っています。 入力プロンプトを含むシーケンスの長さがこのブロックサイズを超えると、そのシーケンスは複数のブロックに分割されます。

具体的に動作を見てみると、例えば入力プロンプト長が9、生成トークン数が12の場合、 合計で21トークンのキャッシュとなり、これは2つのブロックに分割されます。

block_tablesには2つのブロックがある
プロンプト+出力済みトークンが16以上ある

一方で、入力プロンプト長が9、生成トークン数が1の場合ブロック数は1つです。

block_tablesには1つのブロックがある
プロンプト+出力済みトークンが16に満たない

Pythonコード側からは一連の系列としてトークンが確認できますが、実際にGPUで処理が行われる際にはこのブロックに対応するメモリ領域を参照しているということです。

モデル実行の直前のブロックテーブルと思われるデータを参照してみると、 上記画像で示されたブロック番号と想定される値と同様の値が確認できます。 torchのテンソルとして参照可能なようになっています。

> decode_meta.block_tables

tensor([[13058, 13056,     0,  ...,     0,     0,     0],
        [13057, 13055,     0,  ...,     0,     0,     0],
        [13054,     0,     0,  ...,     0,     0,     0],
        ...,
        [    0,     0,     0,  ...,     0,     0,     0],
        [    0,     0,     0,  ...,     0,     0,     0],
        [    0,     0,     0,  ...,     0,     0,     0]],
        device='cuda:0',
        dtype=torch.int32)

こういった解析はデバッガを使ったときの利点のひとつですね。


おわりに

以上、vLLMのコードを読んでみたログになります。 そもそも脱落せずにこの記事をここまで読んでいる方がどの程度か気になりますが…。

vLLMは大きなプロジェクトではありますが、OSSの中では比較的規模の大きすぎないものに思われます。 やってみた感想としては「意外と読み進めやすい」といったところです。 ただ簡単かと言うとそんなことはなく、プロジェクトの大きさによる複雑さよりは、 活動の活発さによって変化が激しいことによるコードの追いづらさがあるのではないかと思われます。

有名なソフトウェアやライブラリも、大半の人は使いこそすれ、その実装まで追いかけることは少ないでしょう。 どういったロジックが動いているのかを実装から知ることは、 そのソフトウェアやライブラリを使う上での理解の一助になります。

また、コードを読むことでそのソフトウェアやライブラリの思想の理解にもつながるはずです。 プロダクトの開発にもきっと役に立つのではないでしょうか?

それでは、お疲れ様でした。