NTTドコモR&Dの技術ブログです。

データのつながりを解き明かす!Graph Embeddingの考え方と適用例の紹介

はじめに

こんにちは、サービスイノベーション部の山路です。 普段の業務ではデジタルマーケティングに関する技術検討や研究開発を行っています。本記事では、近年取り組んでいるグラフデータを活用したレコメンド技術に関して、その根幹となるGraph Embeddingの考え方や適用例について簡単にですがご紹介します。

本記事の対象読者

  • グラフや機械学習の基本的な用語や概念について理解している方
  • Graph Embeddingの手法や実データへの適用方法の考え方に興味がある方
    • 本記事では具体的な実装方法については記載しませんが、参考となる実装例のリンクを貼っていますので実装方法が気になる方は適宜参照してください。

Graph Embeddingについて

Graph Embedding(以下GE)はグラフデータをその意味的な情報を保持しながら低次元のベクトル空間に埋め込む手法です。GEとして扱うことで以下のような利点があります。

グラフの利点

  • 現実世界には、ネットワーク構造を持つデータが数多く存在します。例えば、人間関係ネットワーク、取引ネットワーク、交通ネットワークなどが挙げられます。これらはそのネットワーク構造に基づいてデータが蓄積される傾向にあるため、データをグラフ形式で表現・解釈することでデータの本質を捉えやすくなります。
  • これまでテーブル形式で表現されることの多かったデータ、例えばユーザーとアイテムの関係性といったデータにおいても、グラフを活用することによりユーザー間やアイテム間の関係性を表現しやすくなり、データをより深く正確に理解することができるようになります。

図1: グラフ形式で扱うことでユーザ間やアイテム間に関係性情報を付与しやすくなる

Embedding(埋め込み表現)の利点

  • 埋め込み表現として扱うことで、一般的な機械学習手法を適用しやすくなるため、ノードをラベリングするノード分類やノード間の関係性を予測するリンク予測といったタスクに適用することができます。
  • グラフデータは通常、高次元で複雑な構造を持ちます。例えば、ノード間の接続情報を示す隣接行列はスパースかつ高次元であることが多いため、これを低次元のベクトルとして扱えるようにすることで計算コストを削減できます。
  • 埋め込み表現を2次元や3次元に変換することで、データの可視化や解釈が容易になります。

本記事で取り扱うGE手法

GE手法は多くありますが、本記事では、Node2Vec、GCN、Knowledge Graph Embeddingの3つのGE手法を扱います。これらは、アプローチの異なる基本的な手法であるため、各手法の考え方を理解することは今後さまざまなユースケースにおいてGEを活用する上で有用です。

  • Node2Vec: グラフのネットワーク構造(隣接情報)に基づき、ノードの埋め込み表現を生成します。
  • GCN(Graph Convolutional Network): DNNにおける「畳み込み」を非構造データであるグラフに応用し、隣接ノードの特徴を集約してノードの埋め込み表現を生成します。GNNの一種になります。
  • Knowledge Graph Embedding: 主に知識グラフを対象とし、上記2つの手法とはやや異なりエンティティ間の関係性を捉えることを目的とし、エンティティとリレーションの埋め込み表現を生成します。

※ エンティティ/リレーションは、ノード/エッジを知識グラフの文脈で言い換えたものでほぼ同義と考えて問題無いです。


手法概要

Node2Vec

Node2Vecは、ランダムウォークによって近傍ノードをサンプリングした後、Word2Vecでも使われているskip-gramによってノードの埋め込み表現を学習する手法です。ランダムウォークにより、グラフ上の隣接ノードを効率的にサンプリングでき、計算コストを抑えながらグラフの隣接関係を反映した埋め込みが生成されます。
類似手法であるDeepWalkでもランダムウォークに基づいてサンプリングを行いますが、Node2Vecではローカルな構造(近接ノード)とグローバルな構造(遠方ノード)のどちらを重視するかをパラメータで制御可能です。これにより、特定のタスクに合わせた柔軟な埋め込み表現の学習が可能になります。

図2: Node2Vecのイメージ図。ランダムウォークによって抽出された系列に含まれるノードが"近い"ノードとなる。pやqによってサンプリング方法を調整することで"近さ"の定義を変更できる。

実装例

DGLでのNode2Vecの実装例があります*1

DGLで提供されているnode2vec_random_walk関数の引数pqによってランダムウォークのサンプリング方法を調整できます(簡単にはp > 1、q < 1でグローバル重視、p < 1、q > 1でローカル重視となります)。

以下実装例からの抜粋

from dgl.sampling import node2vec_random_walk

pos_traces = node2vec_random_walk(
    self.g, batch, self.p, self.q, self.walk_length, self.prob
)

GCN

GCN(Graph Convolutional Network)は、非構造データであるグラフに対してニューラルネットワークの考え方を適用したGNN(Graph Neural Network)の一種です。具体的には隣接ノードの特徴量を集約し(メッセージパッシングと呼ばれます)各ノードの特徴量を更新していきます。これにより、各ノードの埋め込みはグラフ構造を反映した特徴量として生成されます。Node2Vecとは違いノードの特徴量を考慮でき、例えばノードに属性情報(年齢、性別など)を特徴量としてもたせることで埋め込みに反映させることができます。また、関係性が複数種類あるような異種グラフを扱う場合は、拡張手法であるRGCN(Relational GCN)が用いることで異種情報を保持した学習も可能です。

図3: 2-LayerのGCNのイメージ図。例えばノード「2」に特徴量を集約する場合、1Layer目(左)でノード「5」にはノード「6」の特徴量が伝搬され、2Layer目(真ん中)でその情報を含んだノード「5」の特徴量がノード「2」に伝搬されるため、Layer数=何hop先までの特徴量を集約するか、であると言える。

実装例

DGLでのGCN実装のドキュメントおよび実装例があります。

DGLではGraphConvが提供されており、PyTorchでニューラルネットを構築する要領でnn.Moduleを継承してGCNを構築できます。

from dgl.nn import GraphConv

class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

DGLではGCN以外にも様々なモデルの実装例(PyTorchの例)が提供されており、例えば前述のRGCNの実装例もあります。

Knowledge Graph Embedding

Knowledge Graph Embedding(以下KGE)は、その名の通り、主に知識グラフを対象にした埋め込み手法です。知識グラフとは、エンティティ(例えば人名、商品名など)とその間の関係性をノードとエッジで表現したグラフデータを指します。 KGEは、知識グラフ補完のようなエンティティ間の関係性を予測することが主な目的であるため、基本的にはリンク予測のタスクに特化しています。エンティティとその関係性に基づいてスコア関数を定義し学習することで、最終的にはノードだけでなく関係性の埋め込みも生成されます。KGEは、新たな関係性の発見や推論などに活用できるため、知識グラフでなくともユーザーとアイテム間の好みや傾向を予測する推薦システムなどへも応用可能です。

図4: KGEの中でもベーシックな手法な1つであるTransEのイメージ図(左)。KGEではhead/relation/tailの低次元ベクトルに対しスコア関数を定義し学習する(右)

実装例

KGEに特化してDGLが提供しているライブラリがあります。 特にscore_fun.pyには、各KGE手法の根幹ともなるスコア関数が定義されており、以下のようにそれぞれの違いが確認できます*2

# TransEのスコア関数の例
def edge_func(self, edges):
    head = edges.src['emb']
    tail = edges.dst['emb']
    rel = edges.data['emb']
    score = head + rel - tail
    return {'score': self.gamma - th.norm(score, p=self.dist_ord, dim=-1)}

# DistMultのスコア関数の例
def edge_func(self, edges):
    head = edges.src['emb']
    tail = edges.dst['emb']
    rel = edges.data['emb']
    score = head * rel * tail
    return {'score': th.sum(score, dim=-1)}

GE手法の適用例

紹介したGE手法について、ノード分類とリンク予測タスクへ適用してみます。 また、ここでは各手法の適用例を紹介することを目的とし、各種パラメータチューニングや埋め込み表現の次元数拡大などによる予測精度向上のための最適化はある程度までで省略していますので予めご了承ください。

ノード分類

ノード分類は、グラフ内の各ノードにラベルを付与するタスクです。本記事では、ノード分類モデルのベンチマークとして広く用いられている論文の引用関係を表現したCoraデータセットを使用します。

図5: coraデータセットのグラフ概要。論文の引用関係をノード間のエッジで表現。

グラフ統計量は以下の通りです。

  • ノード数: 2708
  • エッジ数: 5429
  • 特徴量次元数: 1433
  • ラベル数(論文の種類): 7

手法

Node2VecとGCNの2つの手法で比較します。ちなみに上記で紹介した実装例のリンク先はCoraデータセットを使った実装になっていますので参考にしてみてください。

結果

各手法により32次元の埋め込み表現を生成した後、ロジスティック回帰によるノード分類(7つのラベルの分類)を行いAccuracyを算出してみました。

  • Node2Vec(ランダムウォーク(系列)の長さは30、各ノードに対するランダムウォークの数は10):
    • バランス設定: 0.70 (p=1.0, q=1.0)
    • グローバル重視: 0.68 (p=1.5, q=0.5)
    • ローカル重視: 0.70 (p=0.5, q=1.5)
  • GCN: 0.73

いずれも70%程度の精度で正しくラベリングできており、引用関係情報を用いることで論文の分類がある程度可能であることがわかります。また、引用関係のみならず特徴量も集約しているGCNの方がより高い精度となりました。

次に、生成された埋め込み表現について500ノードをサンプリングしTensorBoadを用いて次元削減(PCA)し可視化しました。同じラベルの論文が近い位置に配置されることが期待されます。

初期の特徴量ベクトル、Node2Vecによる埋め込み、GCNによる埋め込みをそれぞれ可視化しており、同じ色は同じラベルの論文になります。

図6: 初期特徴量ベクトル(左)、 Node2Vecによる埋め込み(ローカル重視)(中央)、GCNによる埋め込み(右)

初期特徴量はばらつきが大きいですが、Node2Vecによってややまとまりがみえ、GCNによってさらにまとまりが見られます。

考察

適用結果から、GCNはNode2Vecよりも高い分類精度を示し、埋め込み表現もより期待されるものになっていると言えそうです。GCNの優位性は、関係性情報のみならず各ノードが保持する特徴量も考慮できることやそれらを階層的に集約しながらグラフ全体の構造を捉えることができることにあります。特にCoraデータセットのように比較的小規模なグラフであればよりその優位性が働くと考えられます。

一方、Node2Vecは関係性情報のみを用いた埋め込み表現であることや、"近さ"を定義するランダムウォークのサンプリングに依存する部分もありGCNほどグラフの構造を捉えることは難しいようです。ただし、"近さ"を定義する上でローカル/グローバル重視のバランスを調整できる柔軟性があり、用途に応じた埋め込み表現を生成したい場合に有用かもしれません。例えばCoraデータセットの場合引用関係にある論文が同じラベルであるケースが多いためローカル重視の方がやや良い精度となったと考えられます。その点でNode2Vecのサンプリング方法は、埋め込みを生成するという用途のみならず、ノードをサンプリングするという用途にも使えそうです。また、計算コストの面では一般的にはNode2Vecの方が軽量であるため(もちろんグラフ構造や探索パラメータなどに依存しますが)、特に特徴量を持たないノードの埋め込みを学習する場合には有用かもしれません。

リンク予測

リンク予測は、グラフ内のノード間の関係性の有無を予測するタスクです。本記事では、Wikipediaから構築した知識グラフを用いてリンク予測を行います。

Wikipediaから知識グラフを構築

今回はDBPediaを使って生成します*3。DBPediaは、Wikipediaの情報をLOD(Linked Open Data)として公開するコミュニティプロジェクトであり、SPARQLというクエリ言語をエンドポイントに対して発行することでトリプレット形式のデータを取得できます。今回は、DBPediaの日本語版であるDBPedia Japaneseからデータを取得し、知識グラフを構築しました。

具体的には以下要件により構築しています。

  • 全国の事業会社のうち各業種の代表的な屋号を計1000ほどピックアップしトリプレットのheadとする
  • Wikipediaのページ内リンクhttp://dbpedia.org/ontology/wikiPageWikiLink 」をトリプレットのrelationとする
  • Wikipediaにてカテゴリとして分類されているキーワードhttp://ja.dbpedia.org/resource/Category 」を接頭語に持つURI*4をトリプレットのtailとする
  • 完全に切り離されたグラフが生成された場合は大きな1つのグラフのみを残す

つまり、ある事業会社のWikipediaページ内にリンクされているカテゴリ情報について、事業会社名とカテゴリ情報の2部グラフとして構築します。 例えば「NTTドコモ」をheadとした場合の上記要件のクエリは以下のように記述できます*5

SELECT ?s ?p ?o
WHERE {
    ?s ?p ?o .
    FILTER (?s = <http://ja.dbpedia.org/resource/NTTドコモ>)
    FILTER (?p = <http://dbpedia.org/ontology/wikiPageWikiLink>)
    FILTER (regex(?o, "^http://ja.dbpedia.org/resource/Category+"))
}

DBPediaが提供しているクエリエディタで実際にクエリを発行してみると2024年10月時点では以下のようなトリプレットが得られます。

図7: DBPediaから取得できるトリプレット

知識グラフとして表すと以下のようになります。

図8: トリプレットを知識グラフとして表した場合の例

上記例を見ていただければ分かりますが、ページ内に含まれるリンクを取得しているためノイズになりうるデータが含まれることに注意が必要です。

参考に、構築されたグラフの1/10スケールでの概形になります。※ 規約上の理由で概形のみの表示になります。

図9: グラフ全体の概形。上側が事業会社、下側がカテゴリ情報。

グラフ全体の統計量は以下の通りです。

  • ノード数: 2620
    • 事業会社(720)/カテゴリ(1900)
  • エッジ数: 4920
    • wikiPageWikiLinkのみ

手法

KGE単体の場合とGCNとKGEを組み合わせる場合(GCN+KGE)の2つを比較します。GCN+KGEはエンコーダーデコーダーモデルの考え方に基づき、まずGCNによるエンコードで埋め込み表現を生成し、その後KGEによるデコードでリンク予測に基づく埋め込み表現を生成します。GNNを活用したリンク予測の実装方法はこちらを参考ください*6。また、KGEは様々な手法がありますが本記事ではリンク予測のベンチマークが比較的良いとされているDistMultを使います。

結果

埋め込みの次元数は32次元としています。リンク予測の評価として一般的に用いられるMRR(Mean Reciprocal Rank)(Filtered*7)を算出してみました。

  • DistMultのみ: 0.32
  • GCN + DistMult: 0.51

MRRはグラフ規模に大きく依存し、今回扱ったグラフが小規模であることを踏まえるとそこまで高い精度であるとは言えませんが、GCN+DistMultの方がより高い精度となりました。

また、リンク予測の定性的な評価として「NTTドコモ」に対する各カテゴリのスコア関数の値上位5件を算出してみました。今回リンク予測としてDistMultを使っているためhead * relation * tailに基づく値を算出しています*8。また「リンク予測」ということで、実際に「NTTドコモ」のWikipediaページ内に含まれているカテゴリデータは除外しています。上位のカテゴリが、「NTTドコモ」のページ内リンクに含まれるカテゴリ情報として望ましいのであれば良い結果と言えます。 ※ 規約上の理由での事業会社名が含まれるデータはマスクしています。

図10: 「NTTドコモ」ノードに対するカテゴリノードのリンク予測スコア上位。DistMultのみ(左)と、 GCN+DistMult(右)。「予測」のためすでにページ内に含まれているカテゴリデータは除外。

「インターネット・サービス・プロバイダ」や「日本の電気通信事業者」が含まれているGCN+DistMultの方が良さそうです。「東証プライム上場企業」「日経平均株価」「TOPIX_100」などは以前まで*9は該当していたため「そのカテゴリに分類されうるか」という観点では性能的に良さそうですが、2024年10月時点では知識データとしては誤ったデータになるので注意が必要です。このあたりは用途に依って良し悪しの判断が必要な部分になります。

さらに、定性的な評価として埋め込みノードの近傍探索をしてみます。同じ種類(事業会社 or カテゴリ)のノードについて類似ノードが近傍にあれば良い結果と言えます*10。ここでは「カテゴリ」ノードについて「日本のファーストフード」ノードの近傍カテゴリ(cos類似度)の上位10件を算出してみました。
※ 規約上の理由での事業会社名が含まれるデータはマスクしています。

図11: 「日本ファーストフード」ノードに対するカテゴリノードの類似度上位。DistMultのみ(左)と、 GCN+DistMult(右)。

GCN+DistMultの方が「牛丼チェーングループ*11」がより上位にきていたり「ハンバーガー店」がランクインしていることを踏まえるとより良い埋め込み表現が生成されていそうです。

考察

適用結果から、GCNとKGEを組み合わせた方がより高い予測精度を示し、類似カテゴリが近くに埋め込まれていることから埋め込み表現もより期待されるものになっていると言えそうです。

特に今回扱った知識グラフのようにノイズデータが多い場合は、関係性情報のみを用いるKGE単体よりも、各隣接ノードを階層的に集約するGCNの方が補正が働き良い結果が導かれやすいのかもしれません。

一方、今回はエッジを1種類のみしか設定していません*12。これは例えば「NTTドコモ」の例について「永田町」と「日本の携帯電話事業者」が同じ関係性として扱われることになり、「日本の携帯電話事業者」に関して類似事業会社を特定したい場合に「永田町」はノイズになります。もし「永田町」には「所在地」、「日本の携帯電話事業者」には「業種」が関係性情報として設定されていれば、それぞれの関係性を正しく区別した上で埋め込みされるため、KGE単体でも精度が良くなると考えられます。また、計算コストの面ではKGEの学習に使われるスコア算出部分はノードとエッジによるスコア関数で定義される行列演算であり一般に高速に計算できるというメリットがあります。グラフに含まれるノイズが少ないと思われる場合や関係性情報がある程度細かく設定されている場合などには、KGEを単体でリンク予測してみることを検討しても良さそうです。

まとめ

本記事では、3つの基本的なGE手法の概要や考え方について示しました。また、適用例を通じてロジックの違いがどのような結果をもたらすかを簡単に示しました。一方で、私自身普段の業務でもグラフデータの扱いにおける難しさは、グラフの構造やデータの性質によって結果がかなり左右される点にあると感じており、その中で重要なことは各手法の考え方を理解した上で、自分のユースケースに適したアプローチを見つけることだと考えています。本記事は簡単なご紹介ではありましたがこれからGraph Embeddingに取り組もうとされているみなさまの一助となれば幸いです。

*1:DGL(Deep Graph Library)とは、GNN(Graph Neural Network)の開発・学習・評価の補助を目的としたPythonのオープンソースライブラリです。

*2:厳密には、edge_func関数に渡すedgesに対する前処理にも各ロジックによる違いが含まれています。

*3:ドキュメントから知識グラフを生成する方法としてLlamaIndexを使った実装があるようですが、今回はクエリによってある程度制御可能なDBPediaを使って小規模でシンプルな構造のグラフを構築することとします。それぞれの比較までは実施していません。

*4:LODの目指す目標の1つは、ウェブ上のデータをリンクし異なるデータセットをつなぐことであり、URIを使ってリソースを一意に特定できるようにすることが推奨されています。

*5:こちらのクエリではセマンティックウェブにおける標準的な記述形式に則り"subject/predicate/object"のそれぞれの頭文字を使っていますが、head/relation/tailとほぼ同義と考えて問題無いです。

*6:ちなみにリンク先マニュアルのDotPredictorがデコーダー部分にあたります。DotPredictorではノード特徴量の内積を使っていますが、KGEを使う場合はスコア関数をここで定義します。

*7:Filteredとは簡単に言うと、評価においては不正解のエッジだがグラフとしてはエッジとして張られている場合にそのエッジをテストデータから除外する評価方法を指します。

*8:ここではheadは「NTTドコモ」埋め込み、relationは「wikiPageWikiLink」埋め込み、tailは各カテゴリノードの埋め込みになり、最終的にはsigmoidにより0~1に収めています

*9:2024年10月時点では日本電信電話の子会社になります。

*10:「事業会社」ノードと「カテゴリ」ノードの類似度のように別の種類のノードの同士の近傍探索の場合は、リンク予測の性能評価で実施したようにスコア関数の定義に則って距離計算をする必要があります。

*11:「〇〇グループ」はWikipediaにてカテゴリデータとして分類されていました。

*12:ちなみにエッジの種類が複数ある場合は異種グラフになるため、前述のRGCN+DistMultを使うことになります。実装例もあります。