[DL輪読会]Train longer, generalize better: closing the generalization gap in large batch training of neural networks

>100 Views

July 21, 17

スライド概要

2017/7/21
Deep Learning JP:
http://deeplearning.jp/seminar-2/

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

(ダウンロード不可)

関連スライド

各ページのテキスト
1.

Train longer, generalize better: closing the generalization gap in large batch training of neural networks Elad Hoffer, Itay Hubara, Daniel Soudry 2017/7/21 発表者:金子貴輝 ※図表または式は明記しない場合,上記論文から引用

2.

概要と選考理由 • 概要 – 深層学習でバッチサイズを大きくできない謎にせまる – シャープミニマの研究の考察と対立して,ラージバッチは 学習が遅くなっているだけだと主張 – バッチサイズに連動する学習率の倍率,スケジューリング, バッチ正則化を修正すればスモールバッチと同等以上 – 学習をランダムポテンシャル上のランダムウォークでモデ ル化することで,更新量の共分散を修正する妥当性を主張 • 理解できませんでした • 選考理由 – 初期値からの重みの学習過程を評価する方法がある? 2

3.

イントロ • 一般化ギャップ – 損失関数が改善しなくなるまで訓練しても, 大きいバッチサイズでは汎化性能が下がる – 並列化に向けて課題 • 経験的に対数的な学習の進みであることから ランダムポテンシャルのランダムウォークでモデル化 できると仮定 • モデル化によって, – 学習率とバッチ正規化を調整でき,一般化ギャップを 5%から1~2%改善 – 汎化性能はtrain lossやvalid lossの目に見える変化がなく なっても重みの初期値からの距離で見ることができ,一 般的な実践,理論的な推奨と対照的 – 反復回数させ増やせばsmall batchと遜色なく一般化でき る 3

4.

ラージバッチ訓練 • 学習法 – Adam,Rmsprop,AdagradよりSGDの単純な変種がよく使われ ている,汎化性能が高くなるため – そのためHeらと同様に,ResNet44で学習率を指数関数的に減 少させた momentum SGDを使用した • 既存研究の経験的観察 – Keskarらのlarge batchの観察 • 図1のように一般化ギャップがある • 損失関数が改善しなくなるまで学習しても消えない • 一般化の高低はミニマムの鋭さと相関していた – 汎化性能の高いミニマムのまま,パラメータ空間での鋭さを変えられる縮退 方向(ニューロンの入出力の単位を変えるなど)が考慮されていない • mini batchの方が初期値から離れられる – Keskarらの仮説 • small batchでは推定雑音によって,鋭く一般化できないミニマムから 脱出できる 4

6.

理論解析:記法 • サンプルごとに定義される損失関数を使って全体の 損失関数を表す • サンプル毎の勾配をミニバッチ平均して推定値 • 勾配推定は平均が真の値で,バッチ間で非相関 – 並び替えでバッチを作るなら完全に,そうでなくてもおおよ そ(付録A) • 物理に例えるとランダムウォークする粒子と慣性で 考えられる 6

7.

理論解析:ガウス型確率場と深層学習 • ポテンシャルをランダム過程にして考える – ガウス型確率場は時空間に広がるガウス過程 • 時空間はパラメータ空間と学習時刻 • 確率変数はバッチによって変動する損失関数 – ローカルミニマムの損失が悪くないと示すのに 平均0,自己共分散がパラメータのユークリッド距 離で決定論的に表せるガウス型確率場を使用 – 損失の高いローカルミニマムが指数的に減ると いう仮説も,おなじモデルで示された • DNNでも現実的なニューラルネットでも証明され始め ている 7

8.

理論解析:ランダムポテンシャル上のランダムウォーク • 自己共分散が距離のα乗に比例するなら, 初期値からの距離平均が対数で近似できる • 平らなポテンシャル上の標準的な拡散 𝑡 に対し, これは「超低速拡散」 • 距離のα乗に比例した高さのポテンシャルが現 れる – 到着にかかる時間はポテンシャルと指数的な関係が あるので,逆に距離は対数で表される 8

9.

理論解析:ランダムポテンシャル(理解してない) • 距離があるほど共分散が大きくなる謎のモデル • 損失の平均は0,分散も0 • 初期値近傍の損失関数がほぼ平坦で,推定の揺れ のほうが多い状況を想定している? 重みの初期値から 離れた場所の ランダムな損失の イメージ? 9

10.

理論解析:実験結果との関係 • 今回の実験でαがいくつなのか,距離の遷移 を調べる • α=2だと判明 • バッチサイズごとに違いがある – 拡散率が定数倍異なる – エポック数で揃えたのでバッチサイズが小さいほ ど長くグラフが続いている • よって,幅が広い最小値ほど?移動に時間 がかかるので,高い拡散率と訓練反復が必 要になる 10

12.

バッチサイズの違う学習間で拡散率をマッチさせる • small batchと重みの増え方を揃える • 学習率 – 付録Aの導出から,重み増分の共分散と学習率とバッチ サイズの関係が近似されるので,バッチサイズの効果を 打ち消すように学習率を設定する – するとランダムウォークの拡散率が揃う – ドリフト値というか平均ステップの増加は一般にオーダー が小さいので無視できると判明 • ミニバッチ勾配推定に乗法ノイズを加えると平均も揃うらしい – 発散を防ぐため,初期反復では勾配のクリッピングまたは 正規化が必要 – dropconnect,ラベル雑音,dropoutでは共分散を揃えるこ とができず,ギャップも減らせず – バッチ正規化ではサンプル勾配がミニバッチに依存する ので話が変わってくる 12

13.

学習率の2乗 重み増分の共分散 バッチサイズ 13 サンプル毎勾配

14.

バッチサイズの違う学習間で拡散率をマッチさせる • ゴーストバッチ正規化 – バッチ正規化はバッチサイズに依存するので,仮想ミニバッチ の統計を使用して,ミニバッチと同等の性能をもたせたい. – 完全なバッチ統計が重要だと実験から判明 – アルゴリズム1は汎化誤差を減少させる – 仮想ミニバッチは分散システムではすでに使われていることが 多いがご利益は知られていない – デバイス内での仮想化はまだ,重み付けも等しいまま – 重みの遷移グラフもおおよそ一致 – 定数のズレがあるが,多分勾配制限のせい,パフォーマンスに 悪影響はないだろう • アルゴリズム – ラージバッチを分割してそれぞれで平均,標準偏差を出す – 分割毎にバッチ正規化 – テスト時の平均,標準偏差の計算はよくわからなかった 14

15.

ラージバッチを分割 それぞれの平均 と標準偏差 それぞれで正規化 15

16.

更新回数をマッチさせる • 次の問題は反復回数 • 初期値からの距離で学習率を計画するという新 しい方法 • Valid errorが安定したとき,訓練誤差を減少させ るとオーバーフィットする恐れがあるので,学習 率を低下させていた • 低下させなくても精度が向上したので,汎化 ギャップは更新回数が原因 • 学習率が一定のエポックの期間をバッチサイズ の比で拡大する • 図3では汎化ギャップは完全になくなっている 16

18.

実験 • 実験設定 – MNIST,CIFAR10,100,ImageNet – Full Connect F1,畳み込みC1,C3,VGG,Resnet44, Wide-Resnet16-4,Alexnet – 学習法は元論文のものとモメンタムSGD – ラージバッチ=4096,スモールバッチ=128(一部256) – 学習率はバッチサイズ比の平方根で修正 – ゴーストバッチサイズは128 – エポックサイズは比を掛けて大きく • 結果 – ミニバッチより良い精度,表1 – ImageNetでは時間がかかるのでまだ良い精度まで は出ていない 18

20.

議論 • 一般化が不十分なシャープミニマの結果を否定し, ラージバッチでも問題ない • ラージバッチで計算時間をどれだけ短縮できるかを考 えていきたい • 学習率のスケジューリングの常識は間違っているかも • 結論 – 学習の初期段階をランダムポテンシャル上のランダム ウォークでモデル化し,重みの距離が対数増加すると近 似できる – ランダムな勾配推定の統計量をバッチサイズ毎に揃える テクニックで,パフォーマンス低下を防ぐことができる – Ghost-BNは訓練時間を増やさずにパフォーマンスを大幅 に向上させることができる 20

21.

付録 • 重み増分の共分散の導出 – バッチに含まれるかの確率変数を置くと導出でき る • ランダムポテンシャルのαの推定 – 図2と同じくらいの距離まで重みをランダムウォー クさせるように長さ[0,10]ぐらいのランダム方向に パラメータを歩かせて距離とコストをプロットした – 標準偏差が比例したのでα=2 – bはビン幅 21