3.9K Views
October 23, 23
スライド概要
深層ボルツマンマシンは,生成モデルとして有望だが,その学習が難しく,事前学習が必要とされてきた.本稿では,深層ボルツマンマシンの学習において,不偏マルコフ連鎖法に基づく勾配推定法を提案する.提案法は,勾配推定を安定かつ高速に行うことができるため,事前学習なしでの深層ボルツマンマシンの学習を可能にする.実験では,画像の生成モデルとして深層ボルツマンマシンを提案法を用いて学習し,既存手法と比較してその有効性を検証する.
不偏マルコフ連鎖モンテカルロ法を用いた 深層ボルツマンマシンの一気通貫学習 谷口 尚平、鈴木 雅大、岩澤 有祐、松尾 豊 1
深層生成モデル • 生成モデルの技術が急速に発展 例:画像生成,音声合成,チャットボット https://twitter.com/goodfellow̲ian/status/1084973596236144640?s=20 2 https://j.gifs.com/Y7mBPW.gif
深層生成モデル • モデリング・学習の方法の違いにより, 様々な種類の深層生成モデルが存在 • それぞれに長所・短所がある • 本研究では,ボルツマンマシンを扱う 3 https://ieeexplore.ieee.org/document/8354080
v : 可視層(データ) 深層ボルツマンマシン h : 隠れ層 θ : パラメータ(重み) Deep Boltzmann machine, DBM • エネルギーベースモデルの一種 • 2値変数(v, h)の確率分布を,エネルギー関数Eを用いて以下のように定義 p (v, h , h ; θ) = (1) (2) exp (−E (v, h , h ; θ)) (1) ∑ṽ,h̃(1),h̃(2) exp (−E (ṽ, h̃(1), h̃(2); θ)) E (v, h , h ; θ) = − v W h (1) (2) (2) ⊤ (1) (1) 4 −h (1)⊤ (1) (2) W h
v : 可視層(データ) 深層ボルツマンマシン h : 隠れ層 θ : パラメータ(重み) サンプリング • DBMの確率分布から厳密なサンプリングを行うことは難しい • 代わりに,ギブスサンプリングを用いることが一般的 p (vi = 1 ∣ h ; θ) = (1) (2) p (hk (1) p (hj (1) (1) σ (Wi,: h ) = 1 ∣ h ; θ) = σ (h (1)⊤ (1) = 1 ∣ v, h ; θ) = σ (v (2) (2) W:,k ) (1) W:,j ⊤ + (2) (2) Wj,: h ) 条件付き分布からのサンプリングを繰り返すことで 漸近的に厳密なサンプルに近づく 5
v : 可視層(データ) 深層ボルツマンマシン h : 隠れ層 θ : パラメータ(重み) 学習 • 対数尤度 log p (v; θ) を最大化することで学習 • 以下の確率的勾配を用いてSGDなどで最適化 𝔼 𝔼 = (v; log p θ)] [ (v) pdata pdata(v)p(h ∣ v; θ) [− ∇θ E (v, h , h )] − (1) (2) 𝔼 ∇θ 6 p(ṽ, h̃; θ) [− ∇θ E (ṽ, h̃ (1) , h̃ )] (2)
v : 可視層(データ) 深層ボルツマンマシン h : 隠れ層 θ : パラメータ(重み) 学習 ∇θ = pdata(v) [log p (v; θ)] pdata(v)p(h ∣ v; θ) [− ∇θ E (v, h , h )] − (1) (2) p(ṽ, h̃; θ) [− ∇θ E (ṽ, h̃ (1) , h̃ )] (2) pdata (v) p (h ∣ v) 𝔼 データのエネルギーを下げて,モデルのエネルギーを上げることで学習 𝔼 𝔼 p (v, h) 7
v : 可視層(データ) 深層ボルツマンマシン h : 隠れ層 θ : パラメータ(重み) 学習 • (1) (1) 重み Wi,j の勾配は,その重みが接続する変数viとhj の統計量のみによって 決まる ∇Wi,j(1) pdata(v)p(h ∣ v; θ) [vi = ⋅ (1) hj ] − p(ṽ, h̃; θ) [ṽi ⋅ (1) h̃j ] (2) Wi,j についても同様 𝔼 𝔼 𝔼 • (v; log p θ)] [ pdata(v) 誤差逆伝播法に依存しない学習方法 8
深層ボルツマンマシン 課題 ∇θ = pdata(v) [log p (v; θ)] pdata(v)p(h ∣ v; θ) [− ∇θ E (v, h , h )] − (1) (2) • 第1項,第2項ともに解析的に計算できない 𝔼 𝔼 𝔼 ➡ 近似計算が必要 9 p(ṽ, h̃; θ) [− ∇θ E (ṽ, h̃ (1) , h̃ )] (2)
深層ボルツマンマシン 長所 短所 • 表現力が高い • 学習が難しい • 学習に誤差逆伝播法を使わない • 特に多層の場合には,各層ごとに 学習するなどの事前学習が必要 • 生物学的妥当性が高い 10
研究課題 DBMの勾配推定 ∇θ 𝔼 pdata(v)p(h ∣ v; θ) [− ∇θ E (v, h , h )] − (1) (2) p(ṽ, h̃; θ) [− ∇θ E (ṽ, h̃ (1) , h̃ )] (2) この勾配を正確かつ効率的に推定する方法の開発に取り組む 𝔼 𝔼 = pdata(v) [log p (v; θ)] 11
既存手法① 有限ステップギブズサンプリング 期待値をとる確率分布に対して,ギブスサンプリングをTステップ回す ∇θ log p (v; θ) = p(h ∣ v; θ) [− ∇θ E (v, h ≈− , h )] − (1) (1) (2) ∇θ E (v, hT , hT ) (2) p(ṽ, h̃; θ) [− ∇θ E (ṽ, h̃ (1) (1) (2) ∇θ E (ṽT, h̃T , h̃T ) + , h̃ )] (2) 𝔼 𝔼 Tステップのギブスサンプリングで得られるサンプル 12
既存手法① 有限ステップギブズサンプリング https://openreview.net/forum?id=r1eyceSYPr 課題 Tのチューニング • Tが小さいと推定量にバイアスが乗る • Tが大きいと計算効率が下がる 先行研究でも学習途中から性能の劣化が 起こるという報告あり ➡ Tを適応的に決めたい 13
• 確率変数 x ∼ P (x)に関する期待値 [ f(x)]を推定することを考える • xtをMCMCをtステップ回したサンプルとすると,P(xt) → P(x) as t → ∞ • このとき,以下が成り立つ [ f (x0) + 𝔼 [ f(x)] = 𝔼 𝔼 不偏MCMC 14 ∞ ∑ t=1 ] f (xt) − f (xt−1)
• 確率変数 x ∼ P (x)に関する期待値 [ f(x)]を推定することを考える • xtをMCMCをtステップ回したサンプルとすると,P(xt) → P(x) as t → ∞ • ytがxtと同じ分布に従うと仮定すると [ f (x0) + 𝔼 [ f(x)] = 𝔼 𝔼 不偏MCMC 15 ∞ ∑ t=1 ] f (xt) − f (yt−1)
• 確率変数 x ∼ P (x)に関する期待値 [ f(x)]を推定することを考える • xtをMCMCをtステップ回したサンプルとすると,P(xt) → P(x) as t → ∞ • ytがxtと同じ分布に従うと仮定すると • さらに,t ≥ τにおいてxt = yt−1が成り立つと仮定すると [ f (x0) + 𝔼 [ f(x)] = 𝔼 𝔼 不偏MCMC 16 τ−1 ∑ t=1 ] f (xt) − f (yt−1)
不偏MCMC • 2つの相関したマルコフ連鎖(xt, yt)を xt = yt−1となるまで回し,以下の量を 計算すれば,期待値 [ f(x)]の不偏推定量が得られる f (x0) + τ−1 ∑ t=1 f (xt) − f (yt−1) 𝔼 • x = (v, h)とすれば,DBMにも不偏MCMCを適用できる 17 https://statisfaction.wordpress.com/2017/08/14/unbiased-mcmc-with-couplings/
既存手法② ギブスベースの不偏MCMC • マルコフ連鎖の組(xt, yt)をギブスサンプリングに基づいて設計 • 不偏MCMCを用いることで,学習途中から性能が劣化する現象を抑えられる ことが報告されている 18 https://openreview.net/forum?id=r1eyceSYPr
既存手法② ギブスベースの不偏MCMC 課題 • 状態xが高次元になると,xt = yt−1となる までにかかるステップ数が指数的に増える ➡ 次元の呪い 不偏MCMCによる勾配推定の不偏性を保ちつつ 高次元にもスケールする手法が必要 19
提案法 メトロポリス=ヘイスティングス 1. x0を初期化する 2. 候補 x′を一様分布からサンプリング 繰り返す 3. 確率 min (1, exp (E (xt; θ) − E (x′; θ))) で xt+1 ← x′ とし,    それ以外の場合は xt+1 ← xt とする 20
提案法 メトロポリス=ヘイスティングス 1. x0を初期化する 2. 候補 x′を一様分布からサンプリング 繰り返す 3. 確率 min (1, exp (E (xt; θ) − E (x′; θ))) で xt+1 ← x′ とし, それ以外の場合は xt+1 ← xt とする E(x0)が非常に低い場合,高い確率で x0 = x1 が起きる    ➡E(x0)が低いx0で初期化すれば,不偏MCMCのステップ数を減らせるはず 21
提案法 局所最頻値による初期化 • 局所探索法を用いて,E(x)の局所最小解を求め,それをx0とする 1. x0 = (1) 2. x0 ← argmin {E(x0) ∣ 繰り返す (2) 3. x0 (1) (2) {x0 , x0 } を一様ノイズで初期化 (1) x0 ← argmin {E(x0) ∣ (2) x0 (2) x0 } (1) x0 } 有限回で局所最小解に収束 22
提案法 局所最頻値による初期化 • (1) DBMの場合は,x 0 = (2) v , h , ( 0 0 ) (2) x0 = (1) h0 とすれば,各ステップの最小化 問題が解析的に解ける argmin {E ∣ (2) v0,h (1) h0 } = (1W(1)h(1)≥0, 1h(1)⊤W(2)≥0) argmin {E ∣ v0, h } = 1v⊤W(1)+W(2)h(2)≥0 (2) h(1) 23
実験 トイデータ 重みをランダムに初期化したDBMに対して 不偏MCMCを回した際のステップ数を比較 • 提案法は次元の呪いを回避できる • 局所探索法のステップ数 (T) は高次元で 増加するが,増え方は線形以下 • 不偏MCMCのステップ数 (τ) は高次元では ほぼ常に1 24
実験 画像生成 • MNISTとFashion-MNISTに対してDBMを end-to-endに学習させる • 既存手法は,層ごとの事前学習なしでは 学習に失敗する • 提案法はend-to-endに学習可能 25
実験 画像生成 • MNIST, CIFAR-10に対して提案法で DBMを学習し,FIDで評価 • WGANとほぼ同等の性能を発揮 26
実験 欠損補完 • DBMは欠損値の補完タスクにも使える • vの一部を固定し,エネルギー関数が 小さくなるv, hを局所探索法で探索 MNIST • 得られた値で欠損を補完する 27 Fashion-MNIST
まとめ • 深層ボルツマンマシンは学習が誤差逆伝播法に依存しないモデルとして有望だ が,勾配の推定が難しく,既存法ではend-to-endな学習ができなかった • 本研究では,不偏MCMCのテクニックを用いて勾配の不偏推定を実現 • MCMCに局所最頻値で初期化されたメトロポリス・ヘイスティングス法を用 いることで,効率的に不偏MCMCの計算を行うことができる • 画像生成のベンチマークでWGANなどと同等の性能を実現 28
今後の課題 • 今回の実験では,重みが密なDBMを用いたため,現時点では高解像度画像の ような超高次元なデータに対しては,パラメータ数が爆発する • 畳み込みのような,疎な重み結合をもつモデルを用いることで,高解像度 画像データに対しても適用可能になるはず 29
論文 • ICML 2023に採択 • 英語論文をarXivで公開中 https://arxiv.org/abs/2305.19684 30