2.2K Views
March 06, 25
スライド概要
DL輪読会資料
Masked Diffusion Modelの進展 Shohei Taniguchi, Matsuo Lab
発表概要 Masked Diffusion Model • 最近、拡散モデルでLLMを作る研究が少しずつ出てきている • LLaDA, Mercury • これらでは基本的にMasked Diffusionが使われている • 言語のような離散データのための拡散モデル • これまでの研究の流れをざっと解説 2
背景 Large Language Diffusion Models (LLaDA) • 8Bスケールで学習された拡散モデルベースのLLM • SFTでの事後学習までやっている • 同サイズのLLaMA3と同等か上回る性能 • 今日はこれをメインで紹介します 3
背景 Mercury • Inception Labsから出た拡散モデルベースのコード生成LLM • Diffusionの最初の論文を出したStefano Ermonのスタートアップ • 生成がかなり速い • 論文が出ていないので詳細はわからないが おそらくMasked Diffusionが使われている 4
アジェンダ • 離散拡散モデルの先行研究 • Structured Denoising Diffusion Models in Discrete State-Spaces • Simple and Effective Masked Diffusion Language Models • Simplified and Generalized Masked Diffusion for Discrete Data • Large Language Diffusion Models • 2/14に出た8Bスケールの拡散モデルLLM 5
拡散モデルおさらい 連続データの場合 • データにノイズがかかっていく確率過程を考えると、その逆過程をデータの 生成過程と見做せる • 逆過程に登場するスコア関数を推定すれば良い • その学習は、乗せたノイズを予測することで行える Forward SDE (data noise) score function Reverse SDE (noise 6 data)
離散拡散モデル Structured Denoising Diffusion Models in Discrete State-Spaces • テキストのような離散的なデータの場合、順過程は状態遷移行列 Q で表現さ れるカテゴリ分布になる • Qの設計によって、最終的にどのような分布に収束するかが変わる 7
離散拡散モデル Structured Denoising Diffusion Models in Discrete State-Spaces • 順過程の設計としては主に以下がある 1. Uniform : 徐々に一様分布に近づいていく過程 2. Absorbing : 1つの値(マスク)に収束していく過程 • これが最近使われるMasked Diffusionに対応 • 1度マスクされたら他の値には移らない 8
Masked Diffusion Models • トークンが徐々にマスクされていく過程を順過程とすると 逆過程は徐々にマスクが取り除かれていく過程になる • 自己回帰と違い、ランダムな順番でトークンが生成される Proprietary + Confidential 500 steps Mayor � � said � � � � � � � that � new plan � � � � � � � 700 steps Mayor � Bowser said � meetings � Commissioner � on Thursday that � new plan will be � board in � � 850 steps Mayor Muriel Bowser said after meetings � Commissioner � on Thursday that � new plan will be � board in December � 1000 steps Mayor Muriel Bowser said after meetings with Commissioner Busby on Thursday that the new plan will be on board in December. 9
Masked Diffusion Models 目的関数 • Masked diffusionは、各トークンを確率 t でマスクして、残りのトークンから マスクを予測するというBERTのような方法で学習できる • マスク確率 t も0~1からランダムに選ばれる Pre-training Mask all tokens independently Mask ratio 𝑡 ∼ 𝑈(0,1) • この目的関数は、尤度の下界を最大化する 10 Mask predictor Mask token Remask Non-mask token Random mask
Masked Diffusion Models サンプリング • 逆過程から厳密にサンプリングする場合は、全トークンがマスクされた状態 から、ランダムな順番で1トークンずつ生成する形になる • ただし、実際には一度に複数のトークンを生成してもあまり問題ない • Mercuryが生成が速いのはこのため(おそらく) 11
Large Language Diffusion Models • Masked diffusionベースで初めて8Bスケールで学習されたモデル • 2.3T token事前学習した後に、4.5MのペアデータでSFT • 同サイズのLLaMAと同等か上回る性能 12
Large Language Diffusion Models • 通常のLLMと同じように、事前学習 => 事後学習 (SFT) の流れ • SFTの際には、プロンプト部分はマスクせずに学習させる Pre-training SFT Mask all tokens independently Prompt Sampling Response Prompt Mask ratio 𝑡 ∼ 𝑈(0,1) ... Mask predictor Mask predictor Remask Mask token Remask Non-mask token Random mask ... 13 𝑡= 1 An intermediate step Mask predictor Response 𝑡= 0
Large Language Diffusion Models 事前学習後の性能 • 同サイズのLLaMA3とほぼ同等の性能 • ただし、データ量は1/6くらい • データ量がほぼ同等のLLaMA2よりは だいぶ良い 14
Large Language Diffusion Models 事後学習後の性能 • こっちはLLaMA3に大体負けてる • LLaMA3はSFT+RLをしているという 違いはある • LLaMA2よりはだいぶ良い 15
Large Language Diffusion Models • 従来のLLMと同じように投入計算量に対してきちんと性能がスケールすること も確認 16
Large Language Diffusion Models Reversal curseの回避 • LLMの課題として、テキストの順序を入れ替えるだけで、性能が著しく下がる reversal curseがよく上げられる • 「AはBである」で学習したときに「BはAである」を正しく予測できない • 次トークン予測しか学習しないことが主な原因 • LLaDAはこれが起こりにくい 17
Masked Diffusion 課題 • Masked diffusionは基本的にTransformer Encoderで実装される • Causal maskがないため、KVキャッシュが使えない • Diffusionにより適したアーキテクチャがあるはず • より巨大なモデルでもちゃんとスケールするのかは、まだ不透明 • RLでの事後学習でもworkするかなども気になる 18
まとめ • テキストのような離散データの拡散モデルは、Masked Diffusionが主流 • Masked Diffusionは、BERTのようにランダムマスクの穴埋め予測で学習できる • LLaDAやMercuryの登場でちゃんとスケールしそうなことがわかってきので 今後もっと色々出てきそう 19