【DL輪読会】Prompting Decision Transformer for Few-Shot Policy Generalization

530 Views

September 30, 22

スライド概要

2022/9/30
Deep Learning JP
http://deeplearning.jp/seminar-2/

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

(ダウンロード不可)

関連スライド

各ページのテキスト
1.

DEEP LEARNING JP [DL Papers] 論 文 解 説 : Prompting Decision Transformer for Few-Shot Policy Generalization Ryoichi Takase http://deeplearning.jp/ 1

2.

書誌情報 採録:ICML2022 概要: オフラインメタ強化学習において、ゼロ・少数ショット学習で未知のタスクに適応する手法を提案 Decision Transformerの枠組みへの軌跡プロンプトの導入により、 パラメータ更新を行うことなく未知のタスクへ適応し、高性能を発揮することを示した ※注釈無しの図は本論文から抜粋 2

3.

オフライン強化学習 (オンライン)強化学習: 現在の方策を用いて環境と相互作用し、経験データを収集して方策を学習 状態・報酬 環境 方策 行動 オフライン強化学習: 環境と相互作用せずに、過去の経験データのみを用いて最適な方策を学習 → 環境との相互作用が難しい分野(例、医療・ヘルスケア)への応用が期待されている 状態・行動・報酬 方策 オフラインデータセット 3

4.

オフラインメタ強化学習 オフライン強化学習の課題: 経験データに含まれるタスクのみから学習 → 未知のタスクに対する性能向上が課題 課題解決のために、オフラインメタ強化学習が提案されている オフラインメタ強化学習 [1]: 各タスクの経験データのみを用いる問題設定であり、 未知のタスクに対して少数データで適応できる方策を学習する 図は文献[1]より抜粋 [1] Mitchell, Eric, et al. "Offline meta-reinforcement learning with advantage weighting." International Conference on Machine Learning. PMLR, 2021. 4

5.

研究目的 本研究では、オフラインメタ強化学習の問題設定に自然言語処理の観点からアプローチする 関連研究:プロンプト [2] ゼロ・少数ショット学習で新しいタスクに適応するために、 プロンプトを用いたフレームワークが提案されている タスクの説明といくつかの例を入力の接頭辞として付加することで、 大規模言語モデルのパラメータを更新せずに新しいタスクに適応させる 関連研究:Decision Transformer [3] Transformerモデルが自然言語処理の枠組みを超えて、オフライン強化学習で高性能を発揮 研究目的: 自然言語処理のプロンプトのフレームワークを応用し、オフライン強化学習の未知タスクに対して、 パラメータ更新のないゼロ・少数ショット学習を実現したい [2] Brown, Tom, et al. "Language models are few-shot learners." Advances in neural information processing systems 33 (2020): 1877-1901. [3] Chen, Lili, et al. "Decision transformer: Reinforcement learning via sequence modeling." Advances in neural information processing systems 34 (2021): 15084-15097. 5

6.

問題設定 やりたいこと: 𝒯 𝑡𝑟𝑎𝑖𝑛の経験データで学習後、少数のデモンストレーションで𝒯 𝑡𝑒𝑠𝑡 のタスクに適応する 𝒯 𝑡𝑒𝑠𝑡のタスクに適応する際はパラメータ更新を行わない 𝒟𝑖 𝒫𝑖 𝒫𝑖 記号の説明: 𝒯:タスクの集合 添え字𝑖は各タスク𝒯𝑖 ∈ 𝒯を意味 学習タスク 𝒯 𝑡𝑟𝑎𝑖𝑛 テストタスク 𝒯 𝑡𝑒𝑠𝑡 互いに素 𝒟𝑖 :学習データセット 各学習タスク𝒯𝑖 に対応する経験データ(オフライン強化学習のデータセット) 𝒫𝑖 :少数のデモンストレーション 学習タスク𝒯 𝑡𝑟𝑎𝑖𝑛に対しては、𝒟𝑖 の一部分をサンプリング テストタスク𝒯 𝑡𝑒𝑠𝑡に対しては、人間やエキスパート方策によって取得 6

7.

軌跡プロンプト 軌跡プロンプト: 少数のデモンストレーション𝒫𝑖 からサンプリング ⋆:プロンプトであることを明記 𝑟: Ƹ reward-to-go(現在のステップからエピソード終了までの累積報酬) 𝑠:状態 𝑎:行動 𝐾 ⋆:ステップ長 注)ステップ長が短い(実験では2~40ステップ)ため、模倣学習には使用不可 学習の安定性向上と過学習防止のため、確率的な軌跡プロンプトを導入 エピソード1:(𝑟1 , 𝑠1 , 𝑎1 , 𝑟2 , 𝑠2, 𝑎2 , … , ) エピソード2:(𝑟1 , 𝑠1 , 𝑎1 , 𝑟2 , 𝑠2, 𝑎2 , … , ) エピソード3:(𝑟1 , 𝑠1 , 𝑎1 , 𝑟2 , 𝑠2, 𝑎2 , … , ) ⋮ 𝐻ステップ → ステップ長 𝐾 ⋆ = 𝐽𝐻 𝐽エピソード 7

8.

ネットワーク構造 モデル構造: Decision Transformerと類似 - 大規模言語モデルGPTの縮小版 入力データ: 𝜏 𝑖𝑛𝑝𝑢𝑡 = (𝜏𝑖⋆, 𝜏𝑖 ) 𝜏𝑖⋆:𝐾 ⋆ステップの軌跡プロンプト(𝒫𝑖 から取得) 𝜏𝑖 :直近𝐾ステップの軌跡の履歴(𝒟𝑖 から取得) ⋆ + 𝐾) 1ステップのデータは(𝑠, 𝑎, 𝑟)で1セットなので入力データ長は3(𝐾 Ƹ 8

9.

学習手順 環境と相互作用せずに、オフラインデータから方策を学習 ①履歴𝜏をサンプリング ②プロンプトをサンプリングして𝜏 ⋆取得 ①→ ②→ → 入力データ:𝜏 𝑖𝑛𝑝𝑢𝑡 = [< 𝑝𝑟𝑜𝑚𝑝𝑡 >, 𝑠1 , 𝑎1 , 𝑟1 , 𝑠2 , 𝑎2 , 𝑟2 , … ] ③学習の安定化のために、 バッチデータℬには全ての学習タスクのデータを含める ③→ ④→ ④行動予測誤差を最小化するように勾配降下法を用いて学習 軌跡プロンプトからタスクの情報を把握し、 履歴と組み合わせて次の行動を予測するように学習する 9

10.

テスト手順 環境と相互作用するオンライン環境で評価 ①各エピソードの最初に履歴𝜏を初期化 ②学習手順と同様にプロンプトをサンプリング ③プロンプトと直近の履歴を入力として受け取り行動を生成 ④データを集めながら𝜏をアップデート 入力データ: 𝜏 𝑖𝑛𝑝𝑢𝑡 = [< 𝑝𝑟𝑜𝑚𝑝𝑡 >, 𝑠1 , 𝑎1 , 𝑟1 ] 𝜏 𝑖𝑛𝑝𝑢𝑡 = [< 𝑝𝑟𝑜𝑚𝑝𝑡 >, 𝑠1 , 𝑎1 , 𝑟1 , 𝑠2 , 𝑎2 , 𝑟2 ] ⋮ ①→ ②→ ③→ ④→ 軌跡プロンプトからタスクの情報を把握できるため、 未知のタスクでも適切な行動を決定する 10

11.

環境とデータセット 環境: Cheetah-dir(タスク数2個): 目標方向(前後)に進むタスク Cheetah-vel(学習タスク35個、テストタスク5個): 目標速度(一様分布により決まる)で進むタスク Ant-dir(学習タスク45個、テストタスク5個): 目標方向(一様分布により決まる)に進むタスク Dial(学習タスク6個、テストタスク4個): 6-DOFのロボットを制御するタスク Meta-World reach-v2(学習タスク15個、テストタスク5個): 3次元空間でロボットを目標位置に制御するタスク データセット: Cheetah-dir、Cheetah-vel、Ant-dir: → 文献[1]のデータセットを使用 DialとMeta-World reach-v2: → 熟練方策によってデータを収集 [1] Mitchell, Eric, et al. "Offline meta-reinforcement learning with advantage weighting." International Conference on Machine Learning. PMLR, 2021. 11

12.

ベースラインアルゴリズム Prompt-DT(提案手法)を以下4つのベースラインと比較 Multi-task Offline RL (MT-ORL): トレーニングセットのマルチタスクで学習 Prompt-based Behavior Cloning (Prompt-MT-BC): トレーニングとテスト時にreward-to-goトークンを除外 → reward-to-goトークンの効果を確認するために使用 Multi-task Behavior Cloning (MT-BC-Finetune): プロンプトとreward-to-goトークンの両方を除外し、目標タスクのデータを用いてファインチューニング → プロンプトとreward-to-goトークンの効果を確認するために使用 Meta-Actor Critic with Advantage Weighting (MACAW): オフラインメタ強化学習の手法で、サンプル効率が高いアルゴリズム 12

13.

Prompt-DTの性能評価 実験結果: Prompt-DT(提案手法)はベースラインよりも高性能を発揮 Reward-to-goトークンの効果: Prompt-DTとPrompt-MT-BCは、Dialタスク以外では同程度の性能 → プロンプトにはタスク特定に十分な情報が含まれているが、 Dialタスクのようにプロンプト自体が不十分な場合はreward-to-goトークンが学習を助ける プロンプトとreward-to-goトークンの効果を比較: Prompt-MT-BCの方がMT-ORLより高性能を発揮 → reward-to-goトークンよりもプロンプトの方がタスクを特定するのに有効 13

14.

軌跡プロンプトの量に関する結果 エピソード数𝐽とステップ数𝐻を変化させ、プロンプト長𝐾 ⋆の影響を考察 実験結果: Prompt-DTは、プロンプトの量に依存しない → 少ないステップ数でもタスク固有の情報を特定することが可能 14

15.

軌跡プロンプトの質に関する結果 学習データセット𝒟𝑖 と少数デモンストレーション𝒫𝑖 のデータの質を変えた場合を検証 𝒟𝑖 がexpert・medium・randomの3通り 𝒫𝑖 がexpert・medium・randomの3通り → 3×3=9通りを検証 実験結果: プロンプトがexpert・mediumであれば、学習データセットの質がmediumであっても最適な方策が得られる プロンプトがrandomの場合は学習データセットがexpertでも最適な方策は得られない 15

16.

分布外のタスクに関する結果 テストタスクの目標値を学習タスクの目標値の範囲内ではなく、 範囲外のタスク(分布外タスク)に設定して性能を検証する Ant-dir(学習タスク8個、テストタスク3個): 3個中2個のテストタスクで目標値が学習タスクの範囲外 実験結果: Prompt-DTは他手法と比較して高性能を発揮 → 軌跡プロンプトが分布外のタスクに対して有効であることを示唆 16

17.

まとめ Prompt-DT: オフラインメタ強化学習の問題設定において、 Decision Transformerの枠組みに軌跡プロンプトを導入 → パラメータ更新を行うことなく未知のタスクへの適応を可能とした 実験結果: ベースラインアルゴリズムと比較して高性能を発揮 分布外のタスクに対しても高性能を発揮 17