NTTドコモR&Dの技術ブログです。

Vertical Federated Learning の通信コストを大幅削減!SparseVFL について解説

NTTドコモ R&D Advent Calendar 2023 の1日目の記事です。

井上と申します。アメリカのシリコンバレーにあるドコモの子会社,DOCOMO Innovations, Inc. (DII) で Principal Data Scientist として機械学習の研究開発に従事しています。

DII は Amazon Web Service とパートナーシップを組み,Federated Learning (連合学習, 略して FL) の開発と実用化に取り組んでいます。昨年のアドベントカレンダーでは Vertical Federated Learning (VFL) のチュートリアルを書きました。VFL は訓練の過程においてサーバとクライアントの間で発生する通信コストが大きいという問題があります。そこで私たちは通信コストを削減する手法 SparseVFL を開発し,2023年8月にカリフォルニアの Long Beach で開催された KDD FL4Data-Mining '23: International Workshop on Federated Learning for Distributed Data Mining で発表してきました。本記事では SparseVFL について解説していきます。

なお,SparseVFL の論文はこちら。コードも公開しています:論文用コードAWS実装版

サマリ

SparseVFL は4つの要素で構成されます。

  1. クライアントモデルの出力層で ReLU を採用
  2. Embedding のL1ノルムをロス関数に加算
  3. ランレングス符号化
  4. Gradient のインデックス送信をスキップ (Maked-gradient)

従来の VFL と比較して通信データ量を 68-81% 削減。10 Mbps の場合に訓練時間を 63% 削減。

Vertical Federated Learning のおさらい

VFL の概要は昨年度のアドベントカレンダーの記事をご参照ください。本論文の前提条件は次の通りです。

  • 1台のサーバと複数台のクライアントが存在する
  • クライアントはサーバと通信を行う
  • クライアントは別のクライアントと通信を行わない
  • 各クライアントとサーバの通信は同期的に実行される

図1: VFL の概要

通信コストの問題

通信コストとして通信データ量と訓練時間の両方が考えられますが,本章では通信データ量について説明します。通信データ量の詳細と訓練時間については論文中4.4章をご参照ください。

訓練の過程においてクライアントとサーバは Embedding と Gradient を繰り返し交換します。Epoch 数を  I ,数値データ型のビット長を  Q ,クライアント  m\in[1,M] がサーバへ送信する Embedding  \mathbf{E}_{m} について 1 epoch 分の要素数を  |\mathbf{E}_{m}| とすると通信データ量  S_{m}^{\alpha} [byte] は下記の式で表現できます。

 \displaystyle
S_{m}^{\alpha}=2|\mathbf{E}_{m}|IQ/8

Embedding  \mathbf{E}_{m} のアップロードと Gradient  \mathbf{G}_{m} のダウンロードが発生しますが, |\mathbf{E}_{m}|=|\mathbf{G}_{m}| なので先頭の係数2でまとめています。末尾の8はビットからバイトへの変換です。要素数  |\mathbf{E}_{m}| はサンプル数  N とクライアント出力次元  D_m の積  D_mN になります。もしミニバッチで訓練する場合は要素数  |\mathbf{E}_{m}| はバッチサイズ  b_m ,バッチ分割個数  N/b_m ,クライアント出力次元  D_m の積  b_mD_m\times N/{b_m}=D_mN と表現できます。

もし MNIST データセットを用いる場合,通信データ量はどうなるでしょうか?  I=100 epochs, N=60000 サンプル,クライアントの出力は仮に元の解像度のまま  D_m=28\times28 次元として, Q=32 ビットとすると

 \displaystyle
S_{m}^{\alpha}=2\times(60000\times28\times28)\times100\times32/8\space\mathrm{[byte]}=37.6\space\mathrm{[GB]}

となります。37.6 GB もあると回線に大きな負荷をかけて訓練時間が長くなりそうです。次の章で通信コストを削減する関連研究について紹介していきます。

関連研究

通信データ量の削減にはいろいろなアプローチがあります。

  • クライアントの出力次元  D_m を減らす:上記の例では元の解像度28x28のまま出力していましたが,クライアントモデルの出力次元を削減すれば通信コストは削減可能です。デメリットとして,出力次元が小さすぎると精度が劣化する恐れがあります。
  • ビット長  Q を減らす:Embedding と Gradient を表現する浮動小数点のビット長を32から16や8へ減らせば通信コストは削減可能です。デメリットとして量子化ノイズが増大して精度が劣化する恐れがあります。詳細は Courbariaux et al. の研究をご参照ください。
  • 不要な情報を減らす:例えば Gradient のうち絶対値の小さい要素はモデルパラメータの更新にあまり影響しないので削除しても精度に影響しないという研究があります。このアプローチのデメリットとして情報を捨てすぎると精度が劣化する恐れがあります。詳細は Wangni et al. の研究をご参照ください。
  • クライアントの元データを圧縮してサーバへ一度にすべて預ける:クライアントが元データを PCA や AutoEncoder で圧縮した後,その圧縮データをサーバへすべて送ってしまうというアプローチがあります。送信後はサーバだけで訓練が完結するので通信コストは最初の1回分だけで済みます。デメリットとして,クライアントモデル (ここでは PCA や AutoEncoder) が VFL 全体のロスを考慮できないため,訓練過程でクライアントモデルが最適化されず,精度が劣化する恐れがあります。詳細は Khan et al. の研究をご参照ください。
  • Embedding を再利用する:Epoch 毎に通信するのではなく,例えば次の Epoch ではクライアントからサーバへ Embedding を送信せずに,サーバが前の Epoch で受領済みの Embedding を再利用することで通信データ量を削減する,という研究があります。デメリットとして収束が遅くなる,すなわち訓練時間が長くなる恐れがあります。このアプローチは特に VFL の各クライアントがサーバと非同期で通信を行う場合においては有効に働く可能性があり,SparseVFL を含む他のデータ削減手法と組み合わせ使うことも可能と考えられます。本論文は同期通信を前提としていることから比較対象外としています。詳細は Chen et al. の研究をご参照ください。
  • 部分的な情報をクライアント間で共有して通信データ量を削減する:このアプローチは本論文と前提条件が異なるため比較対象外としています。詳細は Castiglia et al. をご参照ください。

SparseVFL

Embedding と Gradient をスパース化する,言い換えると行列内の要素に多くの0を作ることができれば,最後にランレングス符号化を適用することで効率的な情報圧縮が可能ではないか? と考えました。私たちが提案する手法 SparseVFL は次の4つの要素で構成されます。(1) クライアントモデルの出力層で ReLU を採用,(2) Embedding のL1ノルムをロス関数に加算,(3) ランレングス符号化,(4) Gradient のインデックスをスキップ (Maked-gradient) 。各構成要素を順番に説明していきます。

図2: Embedding も Gradient もスパースにしたい

クライアントモデルの出力層で ReLU を採用

ReLU はよく使われる活性化関数で,負値をすべて0にして出力します。スパース化したいという目的には適しています。SparseVFL はクライアントモデルの最終層として ReLU を採用します。クライアントモデルについては最終層以外の構造は任意で自由に設計できます。

図3: ReLU

Embedding のL1ノルムをロス関数に加算

一般的によく使われるL1正則化はモデルパラメータのL1ノルムをロス関数に加算します。これによりモデルパラメータをスパースにすることができ,モデルの表現力を抑制することで過学習を防ぐ効果があります。しかし SparseVFL はモデルパラメータではなく Embedding をスパースにしたいので,Embedding のL1ノルムをロス関数に加算します。サーバがクライアント  m から受領したサンプル  n の Embedding  \mathbf{h}_{n,m} とそれに対応するラベル  \mathbf{y}_n を用いてロス関数を次の通り設計します。

 \displaystyle
L:=\frac{1}{N}\sum_{n=1}^{N}l(\theta_0,\mathbf{h}_{n,1},...,\mathbf{h}_{n,M};\mathbf{y}_n)+ 
    \frac{\lambda}{MN}\sum_{m=1}^{M}\sum_{n=1}^{N}||\mathbf{h}_{n,m}||_1

第一項は CrossEntropy 等の任意のロス関数です。第二項が Embedding のL1ノルムです。第一項も第二項もサーバ内で計算可能です。

ここまでの2つの構成要素 (ReLU と Embedding のL1ノルム) を採用することで Embedding をスパースにすることができます。図4のような事前実験でクライアントが出力した Embedding です。黒い箇所だけが非0の値です。この例の場合,4次元目の縦一列に着目すると1つの要素しか非0の値を持っていないので効率的な情報圧縮が期待できます。

図4: スパース Embedding

ランレングス符号化

ランレングス符号化とは,例えば  [0,0,0,0,1,1] という元データが与えられた場合に,値とその連続する個数 (0が4 個,1が2個連続して並ぶ) の情報  [0,4,1,2] に変換する手法であり,この符号化より情報量を削減することが可能です。SparseVFL におけるランレングス符号化は従来のランレングス符号化と微妙に異なり,元の数列を非0の値の数列とそのインデックスの情報へ符号化します。例えば図4の Embedding の場合には,まず縦方向に走査して一次元の数列  [0.1, 0.3, 0.2, 0.0, 0.0, 0.0, 0.0, 0.1] に変換した後に,非0の値の数列  [0.1,0.3,0.2,0.1] と非0の値が始まるインデックス  [0,7] および0の値が始まるインデックス  [3] へ符号化します。元々8個の数値で表現していたものを,7個の数値に劣化なく可逆変換して情報量を削減しています。ところで,Embedding も Gradient も 行列なので,サンプル方向 (縦) に走査するか,特徴量方向 (横) に走査するかで圧縮率に影響すると思われます。これは後述の実験1で検証します。クライアントは Embedding を非0の値の数列とそのインデックスをサーバへアップロードします。

図5: ランレングス符号化と Masked-gardient

Gradient のインデックスをスキップ (Maked-gradient)

実は Gradient のうち,Embedding が0である要素と同じ箇所にある要素は削除してもモデルパラメータの更新式に影響しないことが分かりました。証明は論文中4.2章をご参照ください。したがって図5の例ではサーバはクライアントへ  [-0.2,-0.1, -0.4,-0.2] だけを返却し,これ以外の値は破棄することができ,かつインデックス自体も送信する必要がありません。これは符号化した Embedding のインデックスと符号化した Gradient のインデックスが一致するためです。クライアントは符号化した Embedding のインデックスを保持しておいて Gradient の復号時に利用します。Gradient のインデックスの情報が省けるので全体として大幅な通信データ量削減が可能です。図5の例では8個の数値を4個の数値に劣化なく可逆変換して情報量を削減しています。

このアイディアを私たちは Masked-gradient と呼んでいます。Masked-gradient は ReLU の微分の特性に基づいています。最適化手法として少なくとも Adam, SGD, RMSprop のいずれかの場合に Masked-gradient が有効です。これ以外の最適化手法については未検証です。

通信データ量

 \mathbf{E}'_m を Embedding における非0の値の数列, \mathbf{H}_m を連続する非0の先頭のインデックス, \mathbf{T}_m 連続する0の先頭のインデックスとして,SparseVFL の通信データ量は下記の通りです。

 \displaystyle
S_{m}^{(\beta)}=(2|\mathbf{E}'_m|+|\mathbf{H}_m|+|\mathbf{T}_m|)IQ/8

ここで,Masked-gradient 後に残った Gradient を  \mathbf{G}'_m とすると, |\mathbf{E}'_m|=|\mathbf{G}'_m| であることから上記式の  |\mathbf{E}'_m| 係数は2となっています。 \mathbf{H}_m および  \mathbf{T}_m は Embedding  \mathbf{E}'_m のアップロードとともに送信されますが,Gradient  \mathbf{G}'_m のダウンロード時には送信されません。

ところで,SparseVFL にとって  S_{m}^{(\beta)} が最悪となるケースはなんでしょうか? それは Embedding を一次元化したベクトルに非0と0がひとつずつ交互に出現する場合です。このとき, |\mathbf{E}'_m|=|\mathbf{G}'_m|=|\mathbf{H}_m|=|\mathbf{T}_m|=0.5|\mathbf{E}_m| となり,Embedding のアップロード時の通信データ量は従来の1.5倍に増えますが,Masked-gradient のおかげで Gradient のダウンロード時の通信データ量は従来の0.5倍で済みます。したがって訓練過程においては SparseVFL の通信データ量は常に従来の VFL の通信データ量以下,すなわち  S_{m}^{(\beta)}\le S_{m}^{(\alpha)} が成立します。推論過程における通信データ量および対処は論文中4.4章をご参照ください。

なお,インデックス  \mathbf{H}_m および  \mathbf{T}_m は整数であることから  Q よりも短いビット長で表現することも可能です。この場合  S_{m}^{(\beta)} はさらに小さくなります。

実験1: 構成要素の検証

SparseVFL のどの構成要素がデータ量削減に貢献しているか評価しました。下記構成要素の組み合わせを比較しています。

  • クライアントモデル出力層:ReLU, SeLU, eLU
  • Embedding のノルム:L1, L2, なし (-)
  • ランレングス符号化の走査方向 (Traversal):サンプル方向 (Vertical), 特徴量方向 (Horizontal)

なお,各クライアントモデルは1層の Linear と ReLU。サーバモデルは2層の Linear とReLU。としています。図6の通り,クライアントモデル出力層に ReLU を採用することで 371 MB,Embedding のL1ノルムで 128 MB,走査をサンプル方向にとることで 54 MB の通信データ量削減に貢献していることがわかりました。この例においてデータセットは Adult を使用,各手法で精度 ROC-AUC はほぼ同じです。詳細は論文中の Table 2 をご参照ください。

図6: 構成要素の検証

実験2: 通信データ量

  • データセット:Adult, Wine Quality, Covertype
  • クライアント数:3

という条件下で,下記の手法を比較評価しました。

  • 従来のVFL (Original):出力次元  D_m=8 (Adult) または  D_m=4 (Wine Quality, Covertype), Q=32 ビット (Float 32)。
  • クライアント出力次元 (Dim- D_m ):Original から段階的に  D_m=6, 4, 3, 2 と減らしていきます。
  • ビット数 (Q- Q ):Original から  Q=16,8 (Float 16, Float 8) と減らしていきます。
  • PCA, AutoEncoder:クライアントが元データを  D_m=8 次元に圧縮して一度にすべてサーバへアップロード。
  • Top-W-16:Wangni et al. の手法に基づいて絶対値小さな Gradient を削除して,残った値とそのインデックスに変換します。保持する値の個数  W を様々なパターンで試しています。Embedding に対しては下記 SparseVFL-16 と同じ手法を適用しています。L1ノルムの重みを様々なパターンで試しています。 D_m=8  Q=16
  • SparseVFL-16:ReLU,L1ノルム,サンプル方向に走査, D_m=8  Q=16 。L1ノルムの重み  \lambda を様々なパターンで試しています。

図 7の横軸は通信データ量,縦軸は精度 (2値分類は ROC-AUC,多値分類はF1スコア) です。プロットが左上にあるほど通信データ量の削減に有効な手法と言えます。SparseVFL-16 は Original とほぼ同精度を維持しながら,通信データ量を 68-81% 削減できました。また,同一データ量においては他の通信データ量削減手法よりも高い精度を実現しています。

図7: 通信データ量と精度

実験3: 訓練時間 (マシン単体での実験)

訓練時間  \tau はクライアントとサーバの間の通信回線のスループットにも依存します。スループットが 10 Mbps であれば 63% (365.0→135.5 [s]) の,100 Mbps であれば 4% (80.0→77.1 [s]) の訓練時間削減となりました。

ただし,論文では1マシンの1プロセス内でクライアントとサーバの処理をシーケンシャルに処理しており,訓練時間はクライアントの実訓練時間,サーバの実訓練時間,Embedding と Gradient のデータ量と仮定したスループットから算出される理論的な通信時間,をもとに計算しました。実際に分散したマシンを用意してインターネット回線の環境で実験すると通信時間は変わってきます (後述) 。

表中の  m=0 はサーバ, m=1,2,3 はクライアントです。 T_m はサーバまたはクライアントでの実訓練時間, S_m は各クライアントがサーバと交換する Embedding と Gradient のデータ量です。 \tau_{\space 10M} はスループット 10 Mbps, \tau_{\space 100M} はスループット 100 Mbps におけるそれぞれの合計訓練時間です。

Algorithm  m  T_m [s]  S_m [MB]  \tau_{\space 10M} [s]  \tau_{\space 100M} [s]
Original 0
1
2
3
34.9
12.4
13.4
12.4
-
396
396
396
365.0 80.0
SparseVFL-16 0
1
2
3
50.9
16.6
19.6
17.9
-
85
80
58
135.5 77.1

実験4: 訓練時間 (マシン複数での実験)

論文発表後に,DII の Intern の Xiaoyu Wang さんと Senior Research Engineer の守屋さんが SparseVFL をAWS環境 へ実装しました。クライアント数は2, 3, 4,データセットは Adult です。各クライアントは異なるリージョンに配置されています。この実験では SparseVFL-16 は Original に比べて 9.77-12.85% の訓練時間を削減できました。


#Clients
Original
ROC-AUC
Original
 \tau [s]
Sparse-16
ROC-AUC
SparseVFL-16
 \tau [s]
Ratio
ROC-AUC drop
Ratio
Time reduce
2 0.8117 1,187 0.8019 1,071 1.21% 9.77%
3 0.8887 1,575 0.8866 1,405 0.24% 10.79%
4 0.9007 1,758 0.9001 1,532 0.07% 12.85%

詳細は GitHub レポジトリをご参照ください。

まとめ

SparseVFL は精度を維持したまま VFL の通信コスト (データ量と訓練時間) を削減できました。

参考文献

最後に… DOCOMO Innovations, Inc. ってどんなところ?

オフィスは Sunnyvale にあります (2023年春頃に Palo Alto から引っ越してきました)。NTT OneVision Center というビルで,NTT Research 他 NTT グループ企業のシリコンバレー拠点も同じ建物に同居しています。サンフランシスコ国際空港から33分,サンノゼ国際空港から12分くらいです。

DII は FL を活用した企業間連携を模索していますので,シリコンバレーにお立ち寄りの際はお気軽にお声がけください! 論文中に私 (井上) 宛のメールアドレスがあります。

建物外観

著:井上 義隆