152 Views
December 25, 18
スライド概要
2019/01/11
Deep Learning JP:
http://deeplearning.jp/seminar-2/
DL輪読会資料
DEEP LEARNING JP [DL Papers] Attentive Neural Processes Hirono Okamoto, Matsuo Lab http://deeplearning.jp/ 1
書誌情報: Attentive Neural Processes ◼ ICLR 2019 accepted ◼ 著者: Hyunjik Kim(前回輪読会で自分が発表したdisentangling by factorisingと一緒) ◼ Reviewer1 (rating 6, confidence 4) ◼ NPのunder-fittingの本当の原因の分析が不足している ◼ technical detailsが欠けているので再現が難しい ◼ ← Appendix Aとfigure 8に詳細の構造を載せた(著者) ◼ Reviewer2 (rating 6, confidence 4) ◼ NPを改良しているものの,貢献が大きくない ◼ ← 単純な改良だが,NPの欠点をなくしており,貢献は大きいのでは(著者) ◼ Reviewer3 (rating 7, confidence 4) ◼ cross-attentionがANPの予測分散を小さくしたというのは直感的 ◼ self-attentionとcross-attentionのablation studyがみたい ◼ ← 1次元回帰においてはcross-attentionしか使ってないからcross-attentionだけでも性能があがることは示 されている(著者)
論文概要 ◼ Attentive Neural Process (ANP)は,Neural Process (NP)がunderfittingである問題をAttentionの 枠組みを用いることによって解決したモデルである ◼ Neural Process (NP)とは Deep neural network (DNN) のように関数を万能近似できるため,高い表現能力がある ◼ Gaussian Process (GP) のように事前分布を活用し,関数の事後分布を推論できる ◼ ◼ 実験では,非線形回帰・画像補完で不確かさを含むモデリングができることを示した Pros Cons Deep Neural Network 高い表現能力がある 推論時のスケーラビリティがある 事前分布の活用が難しい データが大量に必要 Gaussian Process 不確実性のモデリングができる 事前分布の活用ができる データ数をnとして,訓練時にO(n^3), 推論時にO(n^2)の計算時間がかかる
背景: Neural Process (NP)とは ◼ 一般の教師あり学習は関数の背後にあるデータの関数f をgで近似している(右図 a,b) ◼ 例えば,パラメトリックな関数gを用意して,パラメータ を初期化し,フィッティングする ◼ 事前知識はgの構造や損失関数にいれることができるが, 事前知識の表現は限られてしまう ◼ 一方,NPは,観測データのembeddingを求め,それら を足しあわし,条件付けとする(右図 c) ◼ GPのような方法では,データ数にスケーラビリティがな いため,計算時間が非常にかかってしまう(O(n+m)^3) ◼ NPの良い点 ◼ Scalability: 訓練・予測時間はO(n+m) ◼ Flexibility: いろんな分布を定義できる ◼ Permutation invariance: データ点の順番に不変 n m
背景: Neural Process (NP)の問題とその解決策 ◼ しかし,NPはアンダーフィットしてしまう問題 がある ◼ 不正確な予測平均・大きく見積もられた分散 ◼ 単純にcontext情報を平均していることが問題であ ると仮定 ◼ それぞれの点において,同じ重みを与えていること になるので,デコーダgがどの点が関係する情報を 与えているかを学習するのが難しくなる ◼ アテンション機構を使って上記問題を解決する ◼ GPのように,新しい入力xと訓練データx_iが近け れば予測出力yと訓練データy_iも近い値になるよ うにする ◼ NPと同じく,permutation invarianceも保存される 問題の構造?
背景: Attentionとは ◼ 要素: key(k), value(v), query(q) ◼ 入力qに対して,類似するkを索引し,対 応するvを取り出す ◼ qはkey-valueのペアの順番に対して不変 ◼ ANPで使われる3つのAttention機構 ◼ Laplace Attention ◼ DotProduct Attention ◼ MultiHead Attention DotProduct MultiHead 図引用: http://deeplearning.hatenablog.com/entry/transforme
提案手法: ANPのNPからの変更点 ◼ 入力(x, y)をconcatし,Self-attentionを行う ◼ 訓練データ同士の相互作用をモデリングできる ◼ 例えば,複数の訓練データが重なった場合,queryはすべての点に注目する必要はなく, 少ない点に大きな重みを与えるだけでよい ◼ NPのaggregationをCross-attension機構に置き換える ◼ 新しいデータ点(query)は,予測に関係のある訓練データ(例えば場所が近い点)に注目するようになる 変更箇所
提案手法: より具体的な構造 ◼ Self-attention ◼ 入力: x, yのconcat ◼ 出力: r ◼ Cross-attention ◼ 入力: r(value), x(key), x*(query) ◼ 出力: r* 図引用: http://deeplearning.hatenablog.com/entry/transformer
実験: 1次元回帰(NP vs ANP) ◼ 実験設定: ◼ ANPはself-attentionは使わず,cross-attentionのみを使っている ◼ NPはbottle neck(d)を128, 256, 512, 1024と変えて実験した ◼ 結果: ◼ ANP,特にdot productとmultiheadの収束がiterationでも時計時間でも早かった ◼ NPはdを大きくすれば性能がよくなったが,再構成誤差は途中で頭打ちになった ◼ ANPの計算時間はO(n+m)からO(n(n+m))に増えるが,訓練が収束する時間はむしろ短くなる underfit気味だが, なめらか context error target error epoch 時計時間 underfitしてないが, GPのようになめらかで,context点が なめらかでない 遠い場所では不確かさが増加している
実験: 1次元回帰(GPとの比較) ◼ NPよりもMultihead ANPのほうがGPに近い ◼ しかし,varianceをunderestimateしていることがわかる ◼ 一つの理由として変分推論が予測分散をunderestimateしていることが考えられる
実験: 2次元回帰(画像補完) ◼ 入力: 画像位置x, 出力: ピクセル値y, データ: CelebA(32x32) ◼ ピクセルの場所と値(x, y)をいくつか与えたとき,残りのピクセル値を予測するタスク ◼ それぞれの生成画像は, から3つサンプルし, の平均に対応する ◼ 定性的にも定量的にも,Stacked Multihead ANPはNPよりも正確な画像を出力した
実験: 2次元回帰(画像補完) ◼ 入力: 画像位置x, 出力: ピクセル値y, データ: MNIST ◼ CelebAのときと同様に,ANPの方が定量的に良い結果 ◼ NPはすべての点が与えられても予測分散の値が減っておらず,予測分散をoverestimateして いるといえる(下図赤枠) NP ANP
実験: 2次元回帰(画像補完・Multihead ANPの分析) ◼ 半分画像を隠したとき,残りの画像を予測させるタスク ◼ 見たことがない画像にも汎化した ◼ バツのtarget点が与えられたとき,Multihead ANPのheadがど こを注目しているかを色でわけた(右図) ◼ それぞれのheadに役割があることがわかる
実験: 2次元回帰(解像度変更) ◼ 画像を別の解像度の画像にするタスク ◼ 32x32の画像で訓練したANPは4x4の画像と8x8の画像それ ぞれの画像の解像度を32x32まであげることを可能にした
付録 ◼ 関連研究 ◼ CNP ◼ NP ◼ 再現実装(GP・NP・ANP)
関連研究 Conditional VAE (Sohn, 2015) NPと異なり,xの条件付けが存在しない.応用例を考えると,画像の位置による違いの条件 付けができないため,画像補完はできないということになる.また,CNPと同じように, globalな変数は存在せず,それぞれの画像にたいしてローカルな潜在変数zが存在する. Neural Statistician (Edwards, 2016) CVAEに対し,globalな変数zを考慮したモデル.global変数zを使ってローカル変数zをサン プリングできるため,yの値の分布を生成できる.しかし,CVAEと同様にxの条件付けが存 在しないため,GPやCNPのようにx,yの関係を捉えることができない. Conditional Neural Processes (Garnelo, 2018) Context点(x, y)から得られるrの和と新たなデータ点x*を条件として,yを予測するモデル. Globalな潜在変数が存在しないため,y1, y2, y3のようなそれぞれの分布は出力できるもの の,y1, y2, y3それぞれを一つのまとまりとしてサンプリングできない. Neural processes (Garnelo, 2018) CNPと第一著者は同じ.CNPでは,globalな潜在変数が存在しないため,同じcontextの データを条件としたとき,y1, y2, y3, …のようなそれぞれの値ごとにしかサンプリングがで きない.一方,NPでは,contextで条件づけたglobalな潜在変数が存在するため,y1, y2, y3…を同時にサンプリング,つまり,関数のサンプリングが可能. 画像の場合,xは位置,yはピクセル値
Conditional Neural Processの訓練 ◼ モデル: ◼ ノーテーション ◼ 観測データ O = {(x_i, y_i)}_{i=0}^{n-1} ⊂ X x Y ◼ ターゲットデータ T = {x_i}_{i=n}^{n+m-1} ⊂ X ◼ f: X → Y ◼ 目的: P(f(T) | O, T)をNNを使ってパラメトリックにQ_θでモデル化 ◼ Q_θのモデル化 ◼ MLPのh_θによるembeddingでrを求める ◼ それぞれのrを足し合わせる ◼ rで条件づけたときの新しい入力点からパラメータを求める ◼ Q_θをパラメータφでモデル化する ◼ Q_θの訓練 ◼ Oの部分集合O_NからOを予測するように学習する(n > N) ◼ Nと訓練データをランダムに選ぶ ◼ 勾配法などでQ_θの負の対数尤度の最小化を行う
Conditional Neural Processの実験結果 ◼ 一次元回帰の実験 ◼ aはGPとの比較で,赤がGP,青がCNPの予測 ◼ bは異なるデータセットで異なるカーネルパラメータでの CNPの予測 ◼ GPのほうがなめらかに予測できているものの,CNPは GPと同様に,不確かさをモデリングできており,デー タ点が少ないところでは不確かさが大きくなっているこ とがわかる
Conditional Neural Processの実験結果 ◼ 画像補完(MNIST) ◼ x: 画像のピクセルの座標を[0, 1]^2に正規化したもの ◼ y: ピクセルの値 [0, 1] ◼ 画像の観測点が増えるにつれて,ground truthに近づくことがわかる(画像a) ◼ 不確かさが大きい点の情報から与えていくと,対数尤度が早く大きくなることがわかる(画像b)
Conditional Neural Processの実験結果 ◼ 画像補完(CelebA) ◼ x: 画像のピクセルの座標を[0, 1]^2に正規化したもの random context ◼ y: ピクセルの値 [0, 1]^3 ◼ 画像の観測点が増えるにつれて,ground truthに近づくこと がわかる(画像上) ◼ 未知の画像の半分が隠されていても,残りの画像を予測する ことができる.すなわち,顔は対称的である・顔の下には口 と鼻があるといった全体的な特徴を学習している(画像下) ◼ これはGPでは捉えきれない特徴である ◼ 定量的にも,contextが少ない場合に特に,与えられた点が randomであってもorderedであってもCNPはMSEが小さいこ とが示された(下表) ordered context
Neural Processの訓練 ◼ CNPと異なるのはrからzを正規分布に従って サンプリングする点のみ ◼ ELBO最小化を行う ◼ nはすべての訓練データ ◼ mはtarget点 ◼ (注) CNPと同様に,訓練データをcontextと targetに毎回ランダムに分割して,訓練する
Neural Processの実験結果 ◼ 一次元回帰 ◼ 訓練データが多くなるほど不確か さが小さくなっている ◼ (ANPと比べるとやはり不確かさを 大きく見積もってるようにみえる)
Neural Processの実験結果 ◼ 二次元回帰 ◼ CNPと異なり,sample画像はcontext点が少なくてもぼやっとならずに,いろんなラベ ルのサンプルが出力される
Neural Processの実験結果 ◼ ベイズ最適化 ◼ トンプソンサンプリングを行い,次に探索する点を決定する ◼ ランダムサーチよりも早く最適化されることがわかった
再現実装(GP・NP・ANP) ◼ https://qiita.com/kogepan102/items/d03bc2f0819cbf550e8d GP事前分布 GP事後分布