2.4K Views
May 19, 24
スライド概要
AI・機械学習を勉強したい学生たちが集まる、京都大学の自主ゼミサークルです。私たちのサークルに興味のある方はX(Twitter)をご覧ください!
2024前期輪読会#3 グラフニューラルネットワーク3章3.3~3.5 理学部3回生 山下 素数 0
GNNの頂点分類問題に対する訓練と推論の方法 GNNの頂点分類問題に対する問題設定としては主に転導的学習と帰 納的学習がある。 ⚫ 転導的学習とは、グラフ𝐺 = (𝑉, 𝐸, 𝑋)と教師ラベルの与えられて いる頂点集合𝑉𝐿 があって、ラベルの与えられていない頂点に対し てラベルを予測する問題設定 ⚫ 帰納的学習とは、グラフ𝐺 = (𝑉, 𝐸, 𝑋)と各頂点に対して教師ラベ ルが与えられたときに、別のグラフ𝐺’ = (𝑉’, 𝐸’, 𝑋’)の頂点に対して ラベルを予測する問題設定 GNNでは、転導的学習と帰納的学習はほぼ同じような方法で訓練・ 推論できる。 それでは転導的学習と帰納的学習の訓練・推論方法について見てい く。 1
GNNの訓練と推論の方法 転導的学習での訓練方法 問題設定から、グラフ𝐺 = (𝑉, 𝐸, 𝑋)と教師ラベルの与えられている頂 点集合𝑉𝐿 がある。 次の手順で転導的学習を行うことができる 1. ラベルの与えられている頂点𝑣 ∈ 𝑉𝐿 に対して頂点埋め込み𝑍𝑣 を計算 する。 2. ラベルの与えられている頂点𝑣 ∈ 𝑉𝐿 に対して、頂点埋め込み𝑍𝑣 から 頂点ラベルを予測する。例えば、𝑌𝑣 = 𝑠𝑜𝑓𝑡𝑚𝑎𝑥(𝑍𝑣 )、𝑌𝑣 = 𝑠𝑜𝑓𝑡𝑚𝑎𝑥(𝑀𝐿𝑃𝜃 (𝑍𝑣 ))などと予測する。 3. 損失関数は交差エントロピーを用いて誤差逆伝搬によってパラメー ターの最適化を行う。 4. ラベルの与えられていない頂点𝑣に対して2.と同じような方法でラ ベルを予測する。 2
GNNの訓練と推論の方法 帰納的学習での訓練方法 問題設定から、グラフ𝐺 = (𝑉, 𝐸, 𝑋)と頂点集合𝑉 に対する教師ラベル、 別のグラフ𝐺’ = (𝑉’, 𝐸’, 𝑋’)が与えられる。 次の手順で帰納的学習を行うことができる 1. グラフ𝐺の頂点𝑣 ∈ 𝑉に対して頂点埋め込み𝑍𝑣 を計算する。 2. グラフ𝐺の頂点𝑣 ∈ 𝑉に対して、頂点埋め込み𝑍𝑣 から頂点ラベル を予測する。例えば、𝑌𝑣 = 𝑠𝑜𝑓𝑡𝑚𝑎𝑥(𝑍𝑣 )、𝑌𝑣 = 𝑠𝑜𝑓𝑡𝑚𝑎𝑥(𝑀𝐿𝑃𝜃 (𝑍𝑣 ))などと予測する。 3. 損失関数は交差エントロピーを用いて誤差逆伝搬によってパラ メーターの最適化を行う。 4. グラフ𝐺’の頂点𝑣に対して2.と同じような方法でラベルを予測す る。 3
GNNの訓練と推論の方法 転導的学習でのGNNのアーキテクチャーの例 GNNのアーキテクチャーとして 例えば、 GCNを2層重ねたGNN、 1 መ መ 𝑌𝑣 = 𝑠𝑜𝑓𝑡𝑚𝑎𝑥(𝐴𝑅𝑒𝐿𝑈( 𝐴𝑋𝑊 )𝑊 2 ) を用いて転導的学習をすることができる ここで、 𝐴が隣接行列、𝐷が次数行列、 1 1 − − 𝐴መ = 𝐷 2 𝐴𝐷 2 𝑋はグラフ𝐺の頂点特徴量 である 4
GNNの訓練と推論の方法 GNNの訓練方法 GNNの訓練方法には、すべての頂点に対して頂点埋め込みの計算と 分類を行うフルバッチの方法と、一度にグラフの一部の頂点の計算 と分類を行うミニバッチの方法の二種類がある フルバッチの方法 ミニバッチの方法 メリット 行列計算が高速にできる デメリット メモリ消費量が多い メリット メモリ消費量が少ない デメリット 実装がフルバッチの方法より 複雑。実行時間もフルバッチ の方法より長い 5
GNNの訓練と推論の方法 GNNのミニバッチの方法での訓練 ミニバッチの方法では、グラフの頂点集合𝑉の部分集合𝐵 ⊂ 𝑉に対し て頂点の分類を行う。 GNNの仕組み上、𝐵の近傍の頂点に対しても頂点埋め込みを計算す る必要がある。 𝑁1 (𝑣) 𝑖層目 1層のGNNの1層目の出力を計算するために 1層目の入力として 頂点埋め込み 𝑁1 𝐵 = )𝑣(𝑁 𝐵∈𝑣ڂの頂点の頂点特徴量が必要で、 帰納的に考えると、 集約 𝐿層(𝐿 ≥ 2)のGNNのL層目の出力を計算するために 1層目の入力として 𝑖 + 1層目 𝑁𝐿 𝐵 = 𝐿𝑁∈𝑣ڂ−1 (𝐵) 𝑁(𝑣)の頂点の頂点特徴量 𝑣 が必要 6
GNNの訓練と推論の方法 GNNのミニバッチの方法での訓練 以上の話をまとめると、 𝐿層のGNNをミニバッチの方法で訓練するときは、次のように頂点埋 め込みを計算していく。 ⚫ 1層目の入力に𝑁𝐿 𝐵 の頂点特徴量を渡し、𝑖層目(𝑖 = 1, … , 𝐿)の出 力として、𝑁𝐿−𝑖 (𝐵)の頂点埋め込みを計算する(𝑁0 𝐵 = 𝐵とした) 問題設定として、複数のグラフが与えられることがある(例えば化合 物が複数個など) このようなときは、与えられたデータが連結でない1つのグラフと考 えてミニバッチ法を用いることができる(Pytorch Geometricなど) 例えば、バッチサイズ32で学習するなら、32個のグラフを連結でな い1つの部分グラフと思って学習することを繰り返す。 7
異種混合グラフへのGNNの適用 頂点の種類が複数個のグラフに対するGNNの適用 頂点の種類が複数個のグラフに対してGNNを適用する一番簡単な方 法は、 その頂点がどの種類なのかを表す頂点特徴量をワンホットベクトル で加える ことである。 種類ごとに特徴量ベクトルの次元が異なる場合は、 線形層を間に挟むことによって特徴量ベクトルの次元が同じになる ようにする。 (頂点𝑣が種類𝑐なら、頂点特徴量として𝑊𝑐 𝑋𝑣 を使うということ) 8
異種混合グラフへのGNNの適用(辺の種類が複数ある場合) ( 関係グラフ畳み込みネットワーク(RGCN) 辺の種類が複数ある場合のGNNを3つ紹介する。 一つ目の関係グラフ畳み込みネットワーク(RGCN)は、辺の種類ご とに別々に畳み込みを行う。 𝑙 𝑙層目の頂点𝑣の埋め込みℎ𝑣 を 𝑙+1 ℎ𝑣 = 𝜎(𝑊 𝑙+1 𝑙 ℎ𝑣 1 + 𝑁𝑟 𝑣 𝑟∈𝑅 𝑊𝑟 𝑙+1 𝑙 ℎ𝑢 ) 𝑢∈ 𝑁𝑟 𝑣 と定義する。ただし、 𝜎はシグモイド関数、 𝑅は辺の種類全体の集合、 𝑁𝑟 (𝑣)は頂点𝑣に種類𝑟の辺で隣接する頂点集合 である 9
異種混合グラフへのGNNの適用(辺の種類が複数ある場合) 辺条件付き畳み込み(ECC) 辺条件付き畳み込み(ECC)は、辺の種類に応じて重み行列を変更す 𝑙 るGNNであり、頂点𝑣における埋め込みℎ𝑣 を 𝑙+1 ℎ𝑣 1 𝑙+1 = 𝜎( + 𝑏(𝑙+1) ) 𝑀𝐿𝑃𝜃, 𝑙+1 (𝐹 𝑢,𝑣 )ℎ𝑢 |𝑁(𝑣)| 𝑢∈𝑁(𝑣) と定義する。 ただし、辺𝑒に対して、𝐹𝑒 は辺の特徴量を表すベクトル(例えばワン ホットベクトルにとれる) MLPは全結合層 10
異種混合グラフへのGNNの適用(辺の種類が複数ある場合) 異種混合グラフ注意ネットワーク(HAN) 異種混合グラフ注意ネットワーク(HAN)では、メタパスを用いて遠 くの頂点からの情報も注意機構(attention)で一度に集約する。 ただし、メタパスとは、頂点の種類𝐴1 , … , 𝐴𝑙+1 、辺の種類𝑅1 , … , 𝑅𝑙 に 対して、種類𝑅1 , … , 𝑅𝑙 の辺を通って種類𝐴1 , … , 𝐴𝑙+1 の頂点をたどるパ スのこと。 異種混合グラフ注意ネットワーク(HAN)のアーキテクチャーは次の ように定義される。 11
異種混合グラフへのGNNの適用(辺の種類が複数ある場合) 異種混合グラフ注意ネットワーク(HAN) 実用上はマルチヘッドア テンションの方が良い 頂点𝑖からメタパスを通ってたどりつ くことができる頂点の頂点埋め込み を集約する。(メタパスは𝑃種類ある) 頂点分類問題を考えているので 損失関数は交差エントロピー https://arxiv.org/abs/1903.07293 より引用 12
異種混合グラフへのGNNの適用(辺の種類が複数ある場合) 余談(metapath2vec) metapath2vecはDeepWalkにおけるランダムウォークをメタパス 上を動くランダムウォークに変え、異種混合グラフに対して頂点埋 め込みを作る手法。ここではグラフは無向グラフとして考える。 メタパスを通って移り合える頂点の頂点埋め込みの類似度が似たも のになるように学習する。 softmax関数の分母の部分の計算量が大きいので、ネガティブサン プリングをよく用いる。 https://paperswithcode.com/paper/metapath2vecscalable-representation-learning のpdfより引用 13
異種混合グラフへのGNNの適用(辺の種類が複数ある場合) 余談(metapath2vecでのネガティブサンプリング) ネガティブサンプリングでは、頂点埋め込みの類似度に関する問題 を直接は解かない。 頂点𝑢, 𝑣が与えられたときに、頂点𝑢と頂点𝑣がメタパスを通って移 り合えるかどうかという2値分類問題を解くことにする。 正例はメタパス上を通るランダムウォークによるサンプリングをし てパス上の頂点𝑢 と、頂点𝑢の(パスを部分グラフとして見た時にそ の部分グラフで) 近傍にあるパス上の頂点をとってくることによっ て、負例はランダムサンプリングによって取ってくる。 このようにするとソフトマックス関数の分母の部分の計算が要らな くなる。以下の𝑂(𝑋)の頂点𝑐𝑡 と頂点𝑣に関する期待値が目的関数。 https://paperswithcode.com/paper/metapath2vec-scalable-representation-learning のpdfより引用 14
異種混合グラフへのGNNの適用(辺の種類が複数ある場合) 余談(metapath2vecでのネガティブサンプリング) 参考までに、論文に書いてあるネガティブサンプリングの疑似コー ドを引用しておく。 https://paperswithcode.com/paper/metapath2vecscalable-representation-learning のpdfより引用 15
同変性とメッセージ伝達による定式化の意義 メッセージ伝達によるGNNがなぜ良いのか メッセージ伝達によるGNNの定式化以外にも、隣接行列𝐴を用いて 𝑊𝐴で頂点ラベルを予測するモデルを考えることができる。 しかし、このように定式化した場合には以下のような問題点がある。 ⚫ 異なるグラフに対して予測ができない ⚫ 重みの行列𝑊の大きさが大きいので計算量が大きい ⚫ 同変でない これから同変に関する説明をする。 同変というのは、大雑把には同じ構造のグラフなら同じ出力になる ことを指す。 例えば、同変であるためには、頂点の番号の振り方に出力結果に依 存しない必要がある。(必要条件) 16
同変性とメッセージ伝達による定式化の意義 グラフの同型性の定義 グラフ𝐺1 = (𝑉1 , 𝐸1 , 𝑋)と𝐺2 = (𝑉2 , 𝐸2 , 𝑌)が同型であるとは、 全単射𝑓: 𝑉1 → 𝑉2 が存在し、 𝑋𝑣 = 𝑌𝑓(𝑣) が全ての𝑣 ∈ 𝑉1 について成り立ち、 𝑢, 𝑣 ∈ 𝐸1 ֞ 𝑓 𝑢 , 𝑓 𝑣 ∈ 𝐸2 が全ての𝑢, 𝑣 ∈ 𝑉1 について成り立つこと。このような全単射𝑓を同 型写像といい、𝑣 ∈ 𝑉1 と𝑓 𝑣 ∈ 𝑉2 を同型な頂点という。 17
同変性とメッセージ伝達による定式化の意義 グラフの同変性の定義 グラフを受け取り頂点埋め込み集合を返す関数 𝐺: 𝑉, 𝐸, 𝑋 → 𝑍 ∈ 𝑅 𝑉 × 𝑑 が同変であるとは、任意の同型なグラフ𝐺1 = 𝑉1 , 𝐸1 , 𝑋 と𝐺2 = (𝑉2 , 𝐸2 , 𝑌)と任意の同型写像𝑓: 𝑉1 → 𝑉2 について、 𝑔 𝐺1 𝑣 = 𝑔 𝐺2 𝑓(𝑣) がすべての𝑣 ∈ 𝑉1 について成り立つことをいう。 (𝑙) 𝑙 集約関数がℎ𝑣 と{{ℎ𝑢 |𝑢 ∈ 𝑁(𝑣)}}の関数のメッセージ伝達型のGNN は同変であることが示せる(証明は省略) LSTMによって集約するメッセージ伝達型GNNは頂点の順番に依存 するため同変ではない。 18
19