【DL輪読会】Controlling Large Language Model with Latent Actions

357 Views

January 15, 26

スライド概要

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

ダウンロード

関連スライド

各ページのテキスト
1.

Controlling Large Language Model with Latent Action Yuya IMAI, Matsuo Iwasawa Lab | 1 of 17

2.

書誌情報 タイトル: Controlling Large Language Model with Latent Action 会議: ICML 2025(Poster) 著者: Chengxing Jia, Ziniu Li, Pengyuan Wang, Yi-Chen Li, Zhenyu Hou, Yuxiao Dong, Yang Yu TL;DR: LLMの行動をトークンではなく少数の離散潜在アクションで制御するCoLAを提案し、探索空間を圧縮して RLを効率化する。数学推論やエージェントタスクで性能向上し、Reward Hackingにも比較的頑健。 リンク: https://openreview.net/forum?id=cEKrGCFXPA https://proceedings.mlr.press/v267/jia25e.html https://arxiv.org/abs/2503.21383 https://github.com/LAMDA-RL/CoLA https://huggingface.co/LAMDA-RL/Llama-3.1-CoLA-10B | 2 of 17

3.

背景と課題 LLMを特定のタスクに適応させるために、強化学習(RLHF / RLAIF など)を使う手法は一 般的だが、従来の定式化には課題がある 課題1: 行動空間が巨大すぎる(探索が非効率) LLMが次に出力する各トークンをそのまま行動として扱う 近年のモデルは語彙数が非常に大きい(例: Llama-3 系は 12万語彙以上) → 1ステップの分岐が大きく、探索・クレジット割当が難しいため サンプル効率が悪い 課題2: 構造の欠如 LLMは本来「次トークン予測(next-token prediction)」の生成モデルであり、RLエージェントとして設計されてい ない そのため、報酬に合わせて望ましい振る舞いを安定に制御しづらい | 3 of 17

4.

提案手法: CoLA (Controlling LLMs with Latent Actions) 離散潜在変数を用いた階層的な条件付き生成モデルに拡張し、強化学習の探索空間を削減 主な構成要素: 1. Language World Model (fworld ) 2. Policy Model (π) 3. Inverse Dynamics Model (finverse ) ​ ​ Naive Decoder-only Pipeline CoLA Pipeline | 4 of 17

5.

フレームワーク詳細 1. Language World Model (fworld ) ​ 入力: 過去コンテキスト x1:t , 潜在アクション at Merge moduleを使って、事前学習済みLLMの埋め込みに、潜 在アクションを埋め込みとして注入 出力: 次トークン xt+1 の分布 ​ ​ ​ 2. Policy Model (π) 入力: 過去コンテキスト x1:t 出力: 潜在アクション at の分布 ここをRLで学習することで、LLM本体を大きく変更せずに制御 ​ ​ | 5 of 17

6.

フレームワーク詳細(続き) 3. Inverse Dynamics Model (finverse ) ​ 入力: 過去コンテキスト x1:t と次トークン xt+1 出力: (その遷移を説明する)潜在アクション at の分布 教師なし学習で潜在アクションを抽出するために使用 サイズN のコードブック C (実験ではN = 64)を使用 ​ ​ ​ 全体の推論プロセス Step 1: at ∼ π(⋅∣x1:t ) Step 2: xt+1 = fworld (x1:t , at ) ​ ​ ​ ​ ​ ​ | 6 of 17

7.

学習プロセス 更新対象の整理 (再掲) fworld : (x1:t , at ) ↦ p(xt+1 ) π : x1:t ↦ p(at ) finverse : (x1:t , xt+1 ) ↦ p(at ) ​ Inverse Dynamics Model: θinverse Language World Model: θworld = (θbase , θmerge ) Policy Model: θpolicy ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ 全体像 段階1: 潜在アクション制御の構築(更新: θinverse , θmerge , θpolicy ) 段階2: 指示追従のための世界モデル調整(更新: θbase ) 段階3: 潜在アクション空間でのRL(更新: θpolicy ) ​ ​ ​ ​ ​ | 7 of 17

8.

学習プロセス: 段階1 Constructing Latent Action Control (再掲) ​ 1-1. Inverse Dynamics Model(+Merge module)の学習 Inverse Dynamics Model(finverse )で潜在アクション at を推定 x1:t と at から Language World Model(fworld ) で xt+1 を予測し、予測誤差を最小化 更新: θinverse , θmerge / 凍結: θ^base ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ fworld : (x1:t , at ) ↦ p(xt+1 ) π : x1:t ↦ p(at ) finverse : (x1:t , xt+1 ) ↦ p(at ) ​ ​ 1-2. Policy ModelをBehavior Cloning(BC)で初期化 目的: RL前に安定な初期Policy Modelを作る Policy Model(π)で潜在アクション at を推定 Inverse Dynamics Model(finverse )の at を擬似ラベルにして予測誤差を最小化 ​ ​ 更新: θpolicy / 凍結: θ^inverse ​ ​ ​ | 8 of 17

9.

学習プロセス: 段階2 Fine-Tuning under Action Guidance (再掲) ​ 目的: Language World Model(fworld )を指示追従データに適応させる。 Inverse Dynamics Model(finverse )を固定して、出力 at を使用する。 x1:t と at から Language World Model(fworld ) で xt+1 を予測し、予測誤差を最小化 ​ ​ ​ fworld : (x1:t , at ) ↦ p(xt+1 ) π : x1:t ↦ p(at ) finverse : (x1:t , xt+1 ) ↦ p(at ) ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ 更新: θbase / 凍結: θ^inverse , θ^merge base更新後にPolicy Modelを再度Behavior Cloningで更新 baseによる埋め込みが変わるため ​ ​ ​ | 9 of 17

10.

学習プロセス: 段階3 Latent Action Reinforcement Learning(RL) (再掲) fworld : (x1:t , at ) ↦ p(xt+1 ) π : x1:t ↦ p(at ) finverse : (x1:t , xt+1 ) ↦ p(at ) ​ 与えられるもの: prompt-only データ Drl = {x1:p } と報酬モデル R(x1:T ) 固定: θ^world , θ^inverse (生成の言語能力は世界モデル側に保持) 更新: θpolicy (潜在アクション選択のみ学習) Roll-out(生成) ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ xt+1 ∼ pworld (⋅∣x1:t , at , θ^world ) at ∼ πθpolicy (⋅∣x1:t ), ​ ​ ​ ​ ​ ​ ​ ​ ​ 目的関数 max E[R(x1:T )] ​ θpolicy ​ ​ 実装上は 初期Policyを参照モデルとして、潜在アクション空間でKLを計算して制約・正則化 RL更新はPPO以外にも、GRPO / RLOO / ReMax / REINFORCE++ などのLLM向け手法が選択肢 | 10 of 17

11.

実験結果: 意味的多様性の向上 検証内容: 潜在アクションによる制御が生成テキストの多様性に どう影響するか 検証データから 複数の prefix(過去コンテキスト)をランダ ムに選び、各prefixに対して複数の生成結果を得る BGE-M3埋め込みの cos類似度の総和の逆数 結果 (Fig 2): Latent Action Sampling (青) が高い多様性を示す Random action sampling:latent action をランダムにサ ンプルして world model で生成 Base model sampling:ベースLLM(Llama-3.1-8B)で 通常生成 Random token sampling:トークンをランダムにサンプル して生成 事前学習トークン数が増えるほど多様性は向上 (赤) | 11 of 17

12.

実験結果: 数学推論タスク 検証内容: 数学推論タスクでの性能 結果 (Fig 3): CoLA は Baseline (Llama-3.1-8B SFT) を上回る性能 特にPass@Kにおいて探索能力の高さが示された。RL後、Math500で 42.4 (Baseline 38.2) を達成 Benchmarks Pass@K on Math500 | 12 of 17

13.

効率的な探索: Action-level MCTS MCTS-Q: 潜在アクション空間上でのモンテカルロ木探索 (MCTS) 潜在アクション空間が小さいため、トークン単位よりも探索が 効率的 Q関数(Qwen-Math-2.5-72B reward model)に基づく枝刈り を導入 結果: Math500において、MCTS-Q (CoLA) は 68.2 を達成 Baseline + MCTS 63.2(Baseline + MCTS-Q 63.0、 CoLA + MCTS 65.4)を上回る Math500 Score Comparison Baseline (SFT): 38.2 CoLA (RL): 42.4 Baseline + MCTS: 63.2 Baseline + MCTS-Q: 63.0 CoLA + MCTS: 65.4 CoLA + MCTS-Q: 68.2 | 13 of 17

14.

エージェントタスク: Countdown Game タスク: 与えられた数を使って目標値を計算する。思考過程( <think> )と回答( <answer> )のフォーマット厳守。 結果 (Fig 4): CoLAはBaselineよりも早くフォーマット報酬を獲得(約2倍速) ただしこの設定では正答率は10–15%程度と限定的で、正しく解くのは難しい Reward Curve Response Length | 14 of 17

15.

エージェントタスク: Alfworld & Scienceworld マルチターンRLタスクでの性能検証 結果 (Table 1): CoLA-RLはBaseline-RLと比較して、Seen/Unseenタスクの両方で大幅な性能向上 複雑な環境での探索と適応において優位性を示す BENCHMARK ALFWORLD (Seen) BASE-SFT BASE-RL CoLA-FTA 68.6 68.6 (+0.0) 75.7 ALFWORLD (Unseen) 67.9 71.6 (+3.7) 70.9 CoLA-RL 77.9 (+2.2) 74.6 (+3.7) SCIENCEWORLD (Seen) 17.0 18.0 (+1.0) 24.7 SCIENCEWORLD (Unseen) 17.5 15.6 (-1.9) 20.4 28.4 (+3.7) 21.8 (+1.4) | 15 of 17

16.

Reward Hackingへの頑健性 検証内容: 不完全な報酬モデルを用いたRLHFにおけるReward Hackingの影響 結果 (Fig 5): CoLAはKL制約が弱い場合(KL = 0.00)でも、Baselineに比べてReward Hackingに強い Baselineは意味のない質問を繰り返すなどの縮退が見られたが、CoLAは回答能力を維持 Policy Modelのみを学習し、Language World Model (言語能力) を固定しているため頑健 Win rate vs Baseline Win rate (KL=0 vs KL=0.01) | 16 of 17

17.

まとめ CoLA (Controlling Large Language Models with Latent Actions) LLMを「Policy Model」と「Language World Model」に分離 巨大なトークン空間ではなく、コンパクトな潜在アクション空間でRLを行う 利点 探索効率の向上: Math500やエージェントタスクで高性能 高い制御性: MCTSなどの探索アルゴリズムとの親和性が高い 頑健性: Reward Hackingに対して強く、言語能力を維持しやすい 今後の展望 より多様なBase Modelでの検証 さらに複雑なタスクへの応用 | 17 of 17