【DL輪読会】Transformers are Sample Efficient World Models

1.2K Views

November 25, 22

スライド概要

2022/11/25
Deep Learning JP
http://deeplearning.jp/seminar-2/

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

(ダウンロード不可)

関連スライド

各ページのテキスト
1.

DEEP LEARNING JP [DL Papers] DL輪読会:Transformers are Sample Efficient World Models Ryoichi Takase http://deeplearning.jp/ 1

2.

書誌情報 採録:ICLR2023 under review 概要: Discrete autoencoderとTransformerを組み合わせた世界モデルを提案 モデルベース強化学習を用いてAtari100kベンチマークで高性能を発揮 ※注釈無しの図は本論文から抜粋 2

3.

背景 強化学習の課題: 高性能を発揮するが、学習には非常に多くの経験データを必要とする → サンプル効率が悪い 観測データを直接扱うとタスクと無関係な情報の変化で性能が劣化する 例)ゲーム画面の背景画像など → 汎化性能が低い 世界モデル [1] モデルベース強化学習であり、世界モデル内(想像の中)で方策を学習 → 性能向上に十分な回数を試行可能なためサンプル効率が良い 潜在変数空間における状態遷移のモデル化 → タスクの本質を学習することで汎化性能が向上 想像の中で学習するため世界モデルの精度が性能に直結する [1] Ha, David, and Jürgen Schmidhuber. "World models." 2018. 3

4.

研究目的 関連研究:DreamerV2 [2] 世界モデルベースの強化学習アルゴリズム Atari環境でRainbowを上回る性能を発揮 関連研究:Decision Transformer [3] Transformerモデルが自然言語処理の枠組みを超えて強化学習で高性能を発揮 研究目的: Transformerの系列モデリング技術を応用して高精度な世界モデルを構築 得られた世界モデルを用いて高性能なモデルベース強化学習を実現 [2] Hafner, Danijar, et al. "Mastering atari with discrete world models." 2020. [3] Chen, Lili, et al. "Decision transformer: Reinforcement learning via sequence modeling." 2021. 4

5.

提案する世界モデルの概要 提案手法: IRIS (Imagination with auto-Regression over an Inner Speech) Discrete autoencoderとTransformerを組み合わせて世界モデルを構築 次の状態𝑥ො𝑡+1 、報酬𝑟Ƹ𝑡、エピソードの終了 𝑑መ 𝑡 を予測 ①エンコーダ𝐸が初期フレーム𝑥0をトークン𝑧0に変換(実際の環境情報で初期化) ②デコーダ𝐷がトークン𝑧𝑡 を画像𝑥ො𝑡 に再構成 ③方策𝜋が再構成画像𝑥ො𝑡 から行動𝑎𝑡 をサンプリング መ ④Transformerが報酬𝑟、エピソードの終了 Ƹ 𝑑、次のトークン𝑧 𝑡+1 を予測 ④ ① ② ③ 5

6.

Discrete autoencoder ① エンコーダ 𝑬: 入力画像𝑥𝑡 をvocab size 𝑁 のトークンに変換 Convolutional Neural Network (CNN)により入力画像𝑥𝑡 を出力𝑦𝑡 に変換 トークン𝑧𝑡 を𝑧𝑡𝑘 = argmin𝑖 ቛ𝑦𝑡𝑘 − 𝑒𝑖 ቛで選択 (ℰ = 𝑒𝑖 𝑁 𝑖=1 :対応する埋め込み表) ② デコーダ 𝑫: CNNデコーダを用いてトークンを画像𝑥に再構成 ො Discrete autoencoderの学習: 収集したフレームデータを使用 損失関数としてL2 reconstruction、commitment、perceptualを等しく重みづけ ④ ① ② ③ 6

7.

Transformer ④ Transformer 𝑮: Discrete autoencoderで得たトークンを用いて、潜在空間での状態遷移モデルを学習 時刻𝑡までのトークン𝑧≤𝑡 と行動𝑎≤𝑡 に加えて 時刻𝑡 + 1で既に予測した も使用して予測 Transformerの学習: 損失関数としてTransitionとTerminationには交差エントロピー誤差、 Rewardには交差エントロピー誤差もしくは平均二乗誤差を使用 ④ ① ② ③ 7

8.

学習手順 学習ループ→ (A) 環境との相互作用: 実環境で軌跡データを収集して𝒟に格納 環境との相互作用 世界モデルの更新 方策の更新 (B) 世界モデルの学習: (A) → 1. 学習データを𝒟からサンプリング 2. Discrete autoencoderを更新 3. Transformerを更新 (B) → (C) 方策の学習: 1. 初期フレームを𝒟からサンプリング 2. 世界モデル内で経験データを収集 3. 方策・価値関数を更新 (C) → ※目的関数とハイパーパラメータはDreamerV2を参考に設定 8

9.

ベンチマーク環境 Atari100kベンチマーク: 26種類のAtari ゲームで構成 エージェントは各環境で100kステップの行動が可能 → 人間のゲームプレイ約2時間に相当する ゲーム例:Frostbite (左) と Krull (右) 9

10.

ベースラインアルゴリズム 先読み検索の有無でベースラインを区別: IRIS(提案手法)はMonte Carlo Tree Searchとの組み合わせが可能だが、 本論文では先読み検索なしの手法を比較対象として設定 先読み検索なし: SimPLe [5]、CURL [6]、DrQ [7]、SPR [8] 先読み検索あり: MuZero [9]、EfficientZero [10] [5] Kaiser, Łukasz, et al. "Model Based Reinforcement Learning for Atari." 2019. [6] Srinivas, Aravind, Michael Laskin, and Pieter Abbeel. "CURL: Contrastive Unsupervised Representations for Reinforcement Learning." 2020. [7] Yarats, Denis, Ilya Kostrikov, and Rob Fergus. "Image augmentation is all you need: Regularizing deep reinforcement learning from pixels." 2020. [8] Schwarzer, Max, et al. "Data-efficient reinforcement learning with self-predictive representations." 2020. [9] Schrittwieser, Julian, et al. "Mastering atari, go, chess and shogi by planning with a learned model." 2020. [10] Ye, Weirui, et al. "Mastering atari games with limited data." 2021. 10

11.

数値実験の評価方法 文献[11]に従い正規化スコアを用いて評価を実施 正規化スコアの定義: 𝑠𝑐𝑜𝑟𝑒𝑎𝑔𝑒𝑛𝑡 − 𝑠𝑐𝑜𝑟𝑒𝑟𝑎𝑛𝑑𝑜𝑚 ℎ𝑢𝑚𝑎𝑛 𝑛𝑜𝑟𝑚𝑎𝑙𝑖𝑧𝑒𝑑 𝑠𝑐𝑜𝑟𝑒 = 𝑠𝑐𝑜𝑟𝑒ℎ𝑢𝑚𝑎𝑛 − 𝑠𝑐𝑜𝑟𝑒𝑟𝑎𝑛𝑑𝑜𝑚 𝑠𝑐𝑜𝑟𝑒𝑟𝑎𝑛𝑑𝑜𝑚 :ランダム方策のスコア 𝑠𝑐𝑜𝑟𝑒ℎ𝑢𝑚𝑎𝑛 :人間プレイヤーのスコア 層別ブーストラップによる信頼区間の推定: 平均値(Mean)と中央値(Median)に加えて、 下位25%と上位25%を除いた残りの50%の平均値(Interquartile mean: IQC)の信頼区間を推定 Performance profileの図示: 正規化スコア以上の割合をグラフ化 [11] Agarwal, Rishabh, et al. "Deep reinforcement learning at the edge of the statistical precipice." 2021. 11

12.

信頼区間に関する結果 IRIS(提案手法)は平均値1.046、IQM値0.501を達成 → 26ゲーム中10ゲームで人間のプレイヤーより高い性能を発揮 12

13.

Performance Profileに関する結果 IRIS(提案手法)はベースラインと同等以上の性能 正規化スコアを超える割合が0.5以下の場合は他手法よりも高性能 → Atari100kベンチマークで先読み検索を使用しない最先端技術であることを示唆 高性能 スコアが0以上の割合が100% 低性能 スコアが1以上の割合が約30%(IRISが最も高性能) グラフの見方: 縦軸:正規化スコア以上の割合 横軸:正規化スコア 上にある曲線ほど優れた手法であることを意味 13

14.

実験結果 Pong、Breakout、Boxingのような分布シフトの影響が小さいゲームで特に高性能を発揮 14

15.

実験結果 FrostbiteとKrullのようなサブゲームを段階的にクリアするゲームでは性能を発揮できない場合がある 15

16.

FrostbiteとKrullの結果の考察 Frostbiteで低性能となった考察: 最初のレベルを終了するには、イグルー構築後に画面下部からイグルーに戻るという 稀でかつ一連の長い行動が必要 → 稀な事象は想像上で十分に経験できないため性能が低くなる Krullで高性能となった考察: 次のステージへの移行が頻繁に行われる → 世界モデルがゲームの多様性をうまく反映できたため想像上でも十分に経験できた Frostbite (左) と Krull (右)の3 つの連続レベル 16

17.

世界モデルの性能解析 想像の中で方策を学習するため世界モデルの精度が性能に直結する → 世界モデルの精度を生成画像から確認 性能評価のポイント: Discrete autoencoder: ボール、プレイヤー、敵などの要素を正しく再構成しているか? Transformer: ゲームの重要な仕組み(報酬やエピソード終了)を正しく捉えているか? IRIS(提案手法)の世界モデルの性能解析を以下のケースで実施 KungFuMaster、Pong、BreakoutとGopher 17

18.

KungFuMasterでの性能解析 各シミュレーションで様々な状況(敵の数など)を生成 青枠からプレイヤーに攻撃された敵は姿を消していることが確認できる → 世界モデルはゲームの重要な仕組みを捉えている 4つの軌跡例 シミュレーション開始点 (実環境の情報で初期化) 世界モデルの想像結果 18

19.

Pongでの性能解析 世界モデルはボールの軌道と選手の動きを捉えている 青枠から勝者側のスコアボードが更新されていることが確認できる → ピクセル単位で高精度な予測を実現 実際の結果 → 世界モデルの 生成結果 → シミュレーション開始点 (実環境の情報で初期化) 19

20.

BreakoutとGopherでの性能解析 ゲームの仕組みを高精度に予測 Breakout: 黄枠:レンガを壊すと報酬が得らる 赤枠:ボールを逃すとエピソードが終了 Gopher: 黄枠:穴をふさぐかモグラを倒すと報酬につながる 赤枠:モグラが人参に到達するとエピソードが終了 黄枠:世界モデルが正の報酬を予測するフレーム 赤枠:エピソード終了のを予測しているフレーム 各行は実環境の情報で初期化し、残りの軌道を想像させた結果 20

21.

まとめ IRIS (Imagination with auto-Regression over an Inner Speech): Discrete autoencoderとTransformerを組み合わせた世界モデルを提案 実験結果: Atari100kベンチマークで高性能を発揮 世界モデルはゲームの重要な仕組みを捉えて高精度な予測を実現 → 先読み検索を使用しない手法として最先端技術であることを示唆 21