NTTドコモR&Dの技術ブログです。

【Python】PyTorch で作る Vertical Federated Learning

NTTドコモ R&D Advent Calendar 2022 の1日目の記事です。

井上と申します。アメリカのシリコンバレーにあるドコモの子会社,DOCOMO Innovations, Inc. (DII) でシニアデータサイエンティストとして機械学習の研究開発に従事しています。

現在,DII は Amazon Web Services, Inc. とパートナーシップを組み,Federated Learning (連合学習, FL) の開発に取り組んでいます。 AWS Partner Network (APN) Blog の記事もご覧ください。

本記事は,FL の中でも,特に Vertical Federated Learning (VFL) を PyTorch を用いて作り上げていくチュートリアルです。 なお,本記事末尾に職場の紹介を載せていますので「シリコンバレーとか DII ってどんなところ?」と気になる方はご覧ください!

Federated Learning

FL は複数のマシンが協力して機械学習モデルを作っていく技術であり,特に各マシンが保有するデータに他のマシンがアクセスしないまま (データを保護したまま) 機械学習することを可能にしています。FL は大きく2種類あり,Horizontal Federated Learning (HFL) と Vertical Federated Learning (VFL) に区別されます。代表的なユースケースは下記の通りです。

  • HFL: ひとつのサーバと大量のクライアントデバイスが存在し,1人のデータが1台のデバイスに保存されているようなユースケースに適しています。各クライアントが保有するデータは,列構成は同じですが,ユーザIDが異なります。Google が TensorFlow Federated を提供しており,モバイルキーボードの予測モデルの学習に活用。Google の漫画を読むのが分かりやすいです。
  • VFL: サーバの役割を担う1つの企業とクライアントの役割を担う複数の企業が存在し,大量のユーザデータが各クライアント企業に保存されているようなユースケースに適しています。各クライアントが保有するデータは,列構成が異なるものの,ユーザIDリストは共通しているとします。例えば,医療機関 (診察データ),スーパーマーケット (日々の食事データ),スポーツジム (運動データ) の3社が生データをシェアせずに,3社に共通するユーザに対して,協力してモデル (総合的なヘルスケアAI) を学習するような場合です。OpenMined の PyVertical 等のライブラリが存在します。

HFL vs VFL

プライバシー保護の強化

HFL と VFL のいずれもサーバとクライアントがインターネットを介してモデルファイルもしくは中間データファイルを交換します。ここで交換されるデータはクライアントが保有する生データではありませんが,悪意のあるサーバまたは悪意のある第三者がここで交換されるデータを解析することで,クライアントが保有するデータの中身 (プライバシーに関する情報) が明らかになってしまうリスクが存在します。プライバシー保護を強化するために,例えば下記の手法を FL に組み合わせて使うことができます。

  • Homomorphic Encryption (準同型暗号): 暗号化されたデータ同士の加算と乗算を,およびこれらを組み合わせて機械学習計算を可能にする暗号技術
  • Differential Privacy (差分プライバシー): データにノイズを加えて特定の個人があるデータの集合に含まれているかどうかの判別を困難にする技術

本記事ではこれらのプライバシー保護技術についてはこれ以上は触れずに,VFL の構築方法のみを対象として説明していきます。

PyTorch で VFL を作ってみる

図のような VFL を作っていきます。1台のサーバと1台のクライアントが存在します。

VFL

なお,本来はサーバとクライアントは異なるマシンとして存在しているので,複数のマシンもしくは複数のプロセスを立ち上げるのが VFL としては正しいのですが,本チュートリアルでは1プロセス内でサーバとクライアントの挙動をシミュレートします。

データセット

Adult データセットを使います。これは1994年の米国国勢調査に基づいたデータで,各行が個人を示しており,性別,年齢,職業等の特徴量をもとに,収入が $50K を上回るかどうかを予測します (2値分類)。本記事では前処理としてカテゴリ列は One-Hot でエンコード,数値列は MinMax でスケーリングします。

import os
import random
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler

from tqdm.notebook import tqdm
import time

def load_data():
    tr = pd.read_csv('adult.data', header=None)
    te = pd.read_csv('adult.test', header=None, skiprows=1) # 1行目は不要

    h = [
        'age', 'workclass', 'fnlwgt', 'education', 'education_num', 'marital_status', 'occupation', 'relationship', 
        'race', 'sex', 'capital_gain', 'capital_loss', 'hours_per_week', 'native_country', 'over_50k'
    ]

    tr.columns = h
    te.columns = h
    te.over_50k = te.over_50k.str.split('.', expand=True)[0] # 行末尾に不要なドットが含まれているので除去

    num_cols = ['age', 'fnlwgt', 'education_num', 'capital_gain', 'capital_loss', 'hours_per_week']
    cat_cols = ['workclass', 'education', 'marital_status', 'occupation', 'relationship', 'race', 'sex', 'native_country']
    lab_col  = 'over_50k'
    
    return tr, te, num_cols, cat_cols, lab_col

def create_encoders(dat, cat_cols, num_cols):
    encoders = dict()

    # カテゴリ列はOneHotEncoding
    for c in tqdm(cat_cols):
        enc = OneHotEncoder(handle_unknown='ignore')
        enc.fit(dat[c].astype(str).values.reshape(-1, 1))
        encoders[c] = enc

    # 数値列はMinMaxScaler
    for c in tqdm(num_cols):
        scaler = MinMaxScaler()
        scaler.fit(dat[c].values.reshape(-1, 1))
        encoders[c] = scaler

    return encoders

def encode(dat, cat_cols, num_cols, encoders):

    for c in tqdm(cat_cols):
        out = encoders[c].transform(dat[c].astype(str).values.reshape(-1, 1))
        if not type(out) == np.ndarray:
            out = out.todense()
        keys = [f'{c}_{i}' for i in range(out.shape[1])]
        dat[keys] = out
        dat = dat.drop(c, axis=1)

    for c in tqdm(num_cols):
        out = encoders[c].transform(dat[c].values.reshape(-1, 1)).flatten()
        dat[c] = out

    return dat

class AdultDataset(Dataset):
    def __init__(self, x, y):
        # 本来はxはクライアント,yはサーバ側で保有します
        self.x = torch.Tensor(x.values)
        self.y = torch.Tensor(y.values)
    
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        return self.x[idx,:], self.y[idx]

def preprocess(tr, te, cat_cols, num_cols, lab_col):
    encoders = create_encoders(tr, cat_cols, num_cols)
    tr = encode(tr, cat_cols, num_cols, encoders)
    te = encode(te, cat_cols, num_cols, encoders)
    
    tr[lab_col] = tr[lab_col].replace({' <=50K': 0, ' >50K': 1}) # スペースが含まれていた
    te[lab_col] = te[lab_col].replace({' <=50K': 0, ' >50K': 1}) # スペースが含まれていた
    
    tr_x = tr.drop(lab_col, axis=1)
    tr_y = tr[lab_col]
    te_x = te.drop(lab_col, axis=1)
    te_y = te[lab_col]
    
    tr_ds = AdultDataset(tr_x, tr_y)
    te_ds = AdultDataset(te_x, te_y)
    
    tr_dl = DataLoader(tr_ds, batch_size = 1024, shuffle=True)
    te_dl = DataLoader(te_ds, batch_size = 1024, shuffle=False)
    
    # positive weight
    pos_weight = (tr_y.shape[0] - tr_y.sum()) / tr_y.sum()
    
    return tr_dl, te_dl, torch.FloatTensor([pos_weight])

MLP

本記事ではまず比較用に 3層の MLP を作ります。この MLP をベースに VFL を作っていきます。なお,この MLP と VFL を正しく比較するために,モデルパラメータの初期値が両者で一致するように明示的に初期化処理を行っています。

seed = 42
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    
class MLP(torch.nn.Module):
    def __init__(self, in_size, hidden_size, out_size):
        super(MLP, self).__init__()
        self.i2h = torch.nn.Linear(in_size, hidden_size)
        self.h2h = torch.nn.Linear(hidden_size, hidden_size//2) 
        self.h2o = torch.nn.Linear(hidden_size//2, out_size)   

        # 初期値を固定
        set_seed(seed)
        torch.nn.init.xavier_uniform_(self.i2h.weight.data)            
        torch.nn.init.ones_(self.i2h.bias.data)
        set_seed(seed)
        torch.nn.init.xavier_uniform_(self.h2h.weight.data)
        torch.nn.init.ones_(self.h2h.bias.data)
        set_seed(seed)
        torch.nn.init.xavier_uniform_(self.h2o.weight.data)
        torch.nn.init.ones_(self.h2o.bias.data)        
        
    def forward(self, x):
        h = self.i2h(x)
        h = F.relu(h)
        h = self.h2h(h)
        h = F.relu(h)
        o = self.h2o(h)
        return o
        
def train_MLP(tr_dl, te_dl, pos_weight):
    mlp = MLP(108, 32, 1)
    optimizer = torch.optim.Adam(mlp.parameters(), lr=0.01)
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)    
    
    for epoch in range(30):
        tr_loss = 0
        mlp.train()
        for i, (batch_x, batch_y) in enumerate(tqdm(tr_dl)):
            optimizer.zero_grad()
            pred_y = mlp(batch_x)
            loss = criterion(pred_y.flatten(), batch_y)
            loss.backward()
            optimizer.step()
            tr_loss += loss.item()
        print(f'Epoch: {epoch}, Training loss: {tr_loss:.4f}')

    te_loss = 0
    pred_y_list = []
    true_y_list = []
    mlp.eval()
    for i, (batch_x, batch_y) in enumerate(tqdm(te_dl)):
        pred_y = mlp(batch_x)
        loss = criterion(pred_y.flatten(), batch_y)
        te_loss += loss.item()
        pred_y_list.extend(torch.sigmoid(pred_y.flatten()).detach().tolist())
        true_y_list.extend(batch_y.detach().tolist())
    
    score = roc_auc_score(true_y_list, pred_y_list)
    print(f'Test loss: {te_loss:.4f}')
    print(f'Test ROC-AUC: {score:.4f}')

VFL

前述の MLP を分割して,入力層をクライアントへ,中間層と出力層をサーバへ配置します。VFL においてもモデルパラメータを明示的に初期化しています。

class ClientModel(torch.nn.Module):
    def __init__(self, in_size, hidden_size):
        super(ClientModel, self).__init__()
        self.i2h = torch.nn.Linear(in_size, hidden_size)

        # 初期値を固定
        set_seed(seed)
        torch.nn.init.xavier_uniform_(self.i2h.weight.data)            
        torch.nn.init.ones_(self.i2h.bias.data)
                
    def forward(self, x):
        h = self.i2h(x)
        h = F.relu(h)
        return h
    
class ServerModel(torch.nn.Module):
    def __init__(self, hidden_size, out_size):
        super(ServerModel, self).__init__()
        self.h2h = torch.nn.Linear(hidden_size, hidden_size//2) 
        self.h2o = torch.nn.Linear(hidden_size//2, out_size)   
        
        # 初期値を固定
        set_seed(seed)
        torch.nn.init.xavier_uniform_(self.h2h.weight.data)
        torch.nn.init.ones_(self.h2h.bias.data)
        set_seed(seed)
        torch.nn.init.xavier_uniform_(self.h2o.weight.data)
        torch.nn.init.ones_(self.h2o.bias.data)
        
    def forward(self, h):
        h = self.h2h(h)
        h = F.relu(h)
        o = self.h2o(h)
        return o

Forward-propagation

クライアントは保有するデータを読み込み,クライアントモデルの出力 (embedding) をサーバへ送信します。embedding は次元方向に圧縮されています。サーバは embedding を読み込んで推論を行います。ここで事前に embedding に対して emb.requires_grad_(True) と設定しておくことが大事なポイントです。次節でその理由を説明します。なお,サーバとクライアントがどこで分割されているのか分かりやすくするために,明示的に embedding を emb.pt としてファイルに保存してから交換するようにしています。クライアントが h と呼んでいる変数はサーバにおける emb と同一です。

def train_VFL(tr_dl, te_dl, pos_weight):
    client_model = ClientModel(108, 32)
    server_model = ServerModel(32, 1)
    clinet_optimizer = torch.optim.Adam(client_model.parameters(), lr=0.01)
    server_optimizer = torch.optim.Adam(server_model.parameters(), lr=0.01)
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)    
    
    for epoch in range(30):
        tr_loss = 0
        client_model.train()
        server_model.train()
        for i, (batch_x, batch_y) in enumerate(tqdm(tr_dl)):
            # クライアントがembeddingをサーバへ送信
            clinet_optimizer.zero_grad()
            h = client_model(batch_x)
            torch.save(h.detach(), 'emb.pt')
            
            # サーバがembeddingを受け取ってサーバモデルを更新
            server_optimizer.zero_grad()
            emb = torch.load('emb.pt')
            emb.requires_grad_(True) # あとでembeddingでの微分を取得できるように事前設定
            pred_y = server_model(emb)
            # (下のコードへ続く)

Backward-propagation

クライアントモデルのパラメータを  \theta_1 ,クライアントの出力 (=サーバの入力) を  h ,ロスを  l ,とするとこれらの関係は, l(h(\theta_1)) という合成関数で表現できます。「学習」とはロス  l が小さくなるように,クライアントモデルのパラメータ  \theta_1 をより妥当な値へ更新していく作業ですが,更新のためには  l  \theta_1 での微分  \frac{\partial l}{\partial \theta_1} が必要になります。しかし,VFL ではサーバとクライアントは異なるマシンとして存在しているため,アクセス可能なパラメータが下記の表の通りに異なります。

クライアントモデルのパラメータ  \theta_1 Embedding  h ロス  l
サーバ x o o
クライアント o o x

クライアントは  l にアクセスできないので, \frac{\partial l}{\partial \theta_1} を計算できません。そこで連鎖律を使います。合成関数  l(h(\theta_1))  \theta_1 での微分は

 \displaystyle
\frac{\partial l}{\partial \theta_1}=\frac{\partial l}{\partial h}\frac{\partial h}{\partial \theta_1}

のように, \frac{\partial l}{\partial h}  \frac{\partial h}{\partial \theta_1} の2つの微分に分割して,その積として求めることができます。この微分のことを gradient (勾配) と呼んだりするので,embedding の対になるものとして,以降は
gradient と呼ぶことにします。

サーバ側では  l のサーバモデルのパラメータ  \theta_0 での gradient  \frac{\partial l}{\partial \theta_0} を計算してサーバモデルを更新し,また,gradient  \frac{\partial l}{\partial h} を計算してクライアントへ返却します。前節で emb.requires_grad_(True) と設定することで,あらかじめロス  l の embedding  h における gradient  \frac{\partial l}{\partial h} を取得可能にしておきました。PyTorch の仕様上,これを事前に設定しないと gradient が取得できません。下記コードでは grad = emb.grad で gradient を取得しています。また,クライアントはこれを dldh という名前で受領しています.

クライアントは受領した gradient  \frac{\partial l}{\partial h} と,クライアント自身で計算可能なもうひとつの gradient  \frac{\partial h}{\partial \theta_1}とを積算して,新たな gradient  \frac{\partial l}{\partial \theta_1} を得ます。クライアントにおける h.backward() は内部で  \frac{\partial h}{\partial \theta_1} を計算しています。h.backward(dldh) のように引数に gradient dldh を与えることで,与えられた gradient との積  \frac{\partial l}{\partial \theta_1}=\frac{\partial l}{\partial h}\frac{\partial h}{\partial \theta_1} を計算できます。これでクライアントモデルも更新可能になりました。

コードと数式の対応関係を整理すると次の表の通りです。サーバとクライアントを区別するために明示的に別の変数名を割り当てていますが中身は同じです。

Embedding Gradient
サーバ emb grad
クライアント h dldh
ファイル名 emb.pt grad.pt
数式  h  \frac{\partial l}{\partial h}
            # (上のコードの続き。forループ内部で,サーバ処理の続き)
            loss = criterion(pred_y.flatten(), batch_y)
            loss.backward()
            server_optimizer.step()
            tr_loss += loss.item()
            
            # サーバがgradientをクライアントへ返却
            grad = emb.grad # gradient の取得
            torch.save(grad.detach(), 'grad.pt')
            
            # クライアントがgradientを受け取ってクライアントモデルを更新
            dldh = torch.load('grad.pt')
            h.backward(dldh)
            clinet_optimizer.step()
        
        # End of for all batches

        print(f'Epoch: {epoch}, Training loss: {tr_loss:.4f}')

    # End of for all epochs

    te_loss = 0
    pred_y_list = []
    true_y_list = []
    client_model.eval()
    server_model.eval()
    for i, (batch_x, batch_y) in enumerate(tqdm(te_dl)):
        # クライアントがembeddingをサーバへ送信
        h = client_model(batch_x)
        torch.save(h.detach(), 'emb.pt')
        
        # サーバがembeddingを受け取って出力
        emb = torch.load('emb.pt')
        pred_y = server_model(emb)
        loss = criterion(pred_y.flatten(), batch_y)
        te_loss += loss.item()
        pred_y_list.extend(torch.sigmoid(pred_y.flatten()).detach().tolist())
        true_y_list.extend(batch_y.detach().tolist())

        # Testingデータについてはback-propagationが不要
    
    score = roc_auc_score(true_y_list, pred_y_list)
    print(f'Test loss: {te_loss:.4f}')
    print(f'Test ROC-AUC: {score:.4f}')

複数クライアントの場合

本チュートリアルではサーバ1台・クライアント1台というシンプルな構成としています。VLF に参加するクライアントの数を増やしたい場合は,client_modelclient_optimizer をクライアントの数だけ用意します。このとき,client_model のネットワーク構造は各クライアントにおいて全く異なっていてもよく,クライアント自身が保有するデータの性質に合わせて設計できます。サーバは複数の embedding (emb1, emb2, emb3, ...) を受け付けたのちに,サーバモデルの入り口で emb = torch.cat((emb1, emb2, emb3, ...), dim=1) で結合すればOKです。ただし,各 embedding 同士でユーザ ID の順序は一致している必要があります。Backward においては各 embedding に対応する gradient grad1, grad2, grad3, ... を各クライアントに返却します。

実験

本記事の VFL は単に MLP を分割しただけなので,VFL と MLP の精度は一致するはずです。ロスと ROC-AUC で評価を行います。下記の通りコードを実行していきます。

tr, te, num_cols, cat_cols, lab_col = load_data()
tr_dl, te_dl, pos_weight = preprocess(tr, te, cat_cols, num_cols, lab_col)

MLP

st = time.time()
train_MLP(tr_dl, te_dl, pos_weight)
print(f'Time: {time.time()-st:.4f}')
100% 32/32 [00:00<00:00, 47.50it/s]
Epoch: 0, Training loss: 27.7643
100% 32/32 [00:00<00:00, 44.96it/s]
Epoch: 1, Training loss: 20.5746
...(中略)...
100% 32/32 [00:00<00:00, 45.35it/s]
Epoch: 29, Training loss: 17.2754
100% 16/16 [00:00<00:00, 49.30it/s]
Test loss: 9.5923
Test ROC-AUC: 0.9035
Time: 22.9737

VFL

st = time.time()
train_VFL(tr_dl, te_dl, pos_weight)
print(f'Time: {time.time()-st:.4f}')
100% 32/32 [00:00<00:00, 38.82it/s]
Epoch: 0, Training loss: 27.7643
100% 32/32 [00:00<00:00, 30.38it/s]
Epoch: 1, Training loss: 20.5746
...(中略)...
100% 32/32 [00:00<00:00, 38.27it/s]
Epoch: 29, Training loss: 17.2754
100% 16/16 [00:00<00:00, 43.07it/s]
Test loss: 9.5923
Test ROC-AUC: 0.9035
Time: 25.6535

MLP と VFL とで同じ精度が得られました 🙌

VFL はクライアントとサーバの間でファイルIOが発生している分,MLP より遅くなります。本来はさらに通信時間も追加で必要になります。

まとめ

本記事では PyTorch を用いて VLF を作る方法を解説しました。下記のような関数で適切に gradient を扱うところがポイントです。

  • emb.requires_grad_(True)
  • grad = emb.grad
  • h.backward(dldh)

補足:Split Learning

VFL は SplitNN (Split Learning) とも似ています。クライアント1台の VFL と クライアント1台の SplitNN は同一といえますが,複数クライアントの場合に挙動が異なります。VFL の場合は複数のクライアントが同時に動作できますが,SplitNN の場合は複数のクライアントは1台ずつ動作します。SplitNN において クライアント j は直前に処理を終えたクライアント j-1 のモデルパラメータを,サーバもしくはクライアント j-1 から受領して,クライアント j 自身にコピーした後に処理を開始します。VFL と SplitNN の差異は下記の通りです。

VFL SplitNN
各クライアントモデルの構造 任意 同一
各クライアントの並列計算 o x
サーバの入力層の次元 各クライアントの出力層の次元の合計 1クライアントの出力層の次元

参考文献

最後に… DOCOMO Innovatins, Inc. ってどんなところ?

※本記事中のオフィス所在地と外観写真は2022年12月当時のものです。2023年12月現在,オフィスはサニーベールにあります。

DII は AI,クラウド,無線,デバイスの研究開発,米国企業やスタートアップとの連携によるビジネス開発に従事しています。スタートアップの発掘や投資を専門に行う NTT DOCOMO Ventures, Inc. (NDV) のシリコンバレー支店も同居しています。

オフィスはカリフォルニア州のパロアルトという町にあります。サンフランシスコ国際空港から車で30分程度です。シリコンバレーには米国を代表するIT企業 (Google, Meta, Apple, NVIDIA等) の本社やスタンフォード大学があります。

オフィス外観
DII のオフィスは TIBCO Software の建物の一部に入居しています。

建物内部
建物内

Fish Market
オフィスの近くの Fish Market というレストランが好きです。 Crab Cioppino (写真左)や Garlic Prawn Linguine (右) が美味しいです。

ここ最近は海外出張も徐々に再開しつつあるかと思います。サンフランシスコ国際空港ご利用時はぜひお立ち寄りください。企業様からの協業のご相談も大歓迎です!