2.3K Views
November 07, 24
スライド概要
DL輪読会資料
DEEP LEARNING JP [DL Papers] “BEYOND AUTOREGRESSION: FAST LLMS VIA SELF-DISTILLATION THROUGH TIME” 2024.11.07 Toshiharu Maeba, Matsuo-Iwasawa Lab (M1) http://deeplearning.jp/
書誌情報 紹介論文 タイトル BEYOND AUTOREGRESSION: FAST LLMS VIA SELF-DISTILLATION THROUGH TIME 出典: Arxiv(2024.10)preprint 著者: Justin Deschenaux, Caglar Gulcehre CLAIRE Lab, School of Computer and Communication 概要 • Self-Distillation Through Time(SDTT)を用いた離散拡散言語モデル • サンプリングステップ数を減らすことで推論を高速化 ※画像は出典記載のないものは、本論文から引用 2
目次 • 概要 • 関連研究 • 手法 • 実験 • まとめ・感想 3
概要:主な貢献 ⚫ SDTT(Self-Distillation Through Time)を導入し、一度に少なくとも32個のトー クンを生成することができ、 nucleus samplingによるGPT-2よりも優れた perplexityを実現 ⚫ SDTTは非常にシンプルで実装が簡単 ⚫ SDTTは、KVキャッシングを使用するARモデルよりも最大8倍高速にトークンを 生成可能 ⚫ 最大860Mのパラメータを持つ言語モデルに対するSDTTの有効性を実証 4
関連研究:MASKED DIFFUSION LANGUAGE MODELING ⚫ MDLMは、テキストのマスクされた部分を段階的に復元することで言語生成 ⚫ モデルは、テキストにノイズ(マスク)を加えていき、段階的に除去する過程で、隠されたマスク部分 を予測し、元のテキストを再構成 5
関連研究:MASKED DIFFUSION LANGUAGE MODELING ⚫ MDLM :データにノイズを付加するフォワード プロセスと、データの回復を学習するバックワード プロセス forward x:元のドキュメントxで定義されたone-hot分布 𝝅:定常分布、トークンをマスク 𝛼𝑡 :ノイズ注入スケジュール、t ∈ [0, 1] 、𝛼t ∈ [0, 1]、 𝛼t は t の減少関数(𝛼0 ≈ 1、𝛼1 ≈ 0) backward 𝛼 𝛼𝑠 t > s 、𝛼𝑡|𝑠 = 𝑡 損失関数 NELBO は、Ground Truth xとモデル予測 x𝜃 の間の重み付きクロス エントロピー損失に単純化 制約 ①ノイズ除去されたトークンは、サンプリング中に再マスクされない ②すでにノイズ除去されたトークンは、次のサンプリングステップに持ち越す 6
関連研究:知識蒸留 ⚫ 知識蒸留は、生徒モデルを訓練して、より複雑な教師モデルの予測を模倣する手法 ⚫ 蒸留の主な利点の 1 つは、大規模な LLM からのサンプリングに関連する推論コストを削減しながら、蒸留なし でトレーニングされた小規模なモデルのパフォーマンスを上回ることができること ⚫ 今回の手法では、ダイバージェンスδを使用して教師と生徒の予測を一致させる蒸留方法 𝜇s、𝜇𝑡 :それぞれ生徒と教師のAR分布 D:訓練データセット 7
提案手法:SDTT ⚫ 通常、ステップ数が少ないと、逆プロセスの近似精度が低下するため、サンプルの品質が低下 ⚫ SDTTは、推論時に複数のステップを学生に蒸留することにより、サンプリング速度を向上 ⚫ 蒸留した学生を教師として使用して、SDTTを複数回適用 8
提案手法:SDTT (m) (k) SDTT は、パラメーター ν のデノイザーを学習させて、 p𝜃 と p𝜈 の間のダイバージェンス d を最小化する問題 (k <m、mはkで割り切れるものとする (たとえば、m = 1024 と k = 512)) (m) p𝜃 :パラメータ θ のデノイザーを使用して、m ステップで生成されたサンプルの分布 x𝜃 と x𝜈 をそれぞれ教師からの多くのステップのデコードに使用されるデノイザー、生徒からの数ステップ のデコードに使用されるデノイザーとする (m) (k) サンプリング プロセスの唯一の学習可能な要素であるため、サンプリング分布 p𝜃 と p𝜈 を完全に決定 x𝜃 の予測と一致するように x𝜈 を学習するために、m/k ステップについて教師からサンプリング 9
提案手法:SDTT(アルゴリズム) 10
提案手法:SDTT(アルゴリズム) 11
実験設定 使用モデルなど ⚫ ベースモデル: diffusion transformer、MDLMの原論文のチェックポイントを再利用 ⚫ 訓練データセット:OpenWebTextデータセット • 学習率が 6e − 5、500ステップで直線的に増加し、その後は一定 • オプティマイザー:重み減衰がない Adam • バッチ サイズが 128、 評価指標 ⚫ MAUVE により、モデルがプロンプトにどの程度従うかを評価 ⚫ LAMBADAデータセットにより、精度とperplexityを評価 ⚫ 拡散モデルは、マスクされたすべてのトークンが 1 つのステップで正しく デコードされた場合にのみ正解とする 12
実験:目的関数 ⚫ SDTTでは損失関数として使用するダイバージェンスδを選択する必要があるため、平均二乗誤差(MSE)、 全変動距離(TVD)、KLダイバージェンス(KLD)のいずれが適しているかを実験 ⚫ KLDが最も良い結果 KLDのみ教師を上回る 生徒が1024ステップでなく16ステップで訓練しても GPT-2に匹敵する性能 13
アブレーション:ステップ数 ⚫ 蒸留損失の大きさは収束を確実に示していないため、より短いラウンドで実験 ⚫ 短いラウンドはわずかに最終的な生成的なperplexityを改善 ⚫ SDTTはELBOを直接最適化しないため、ステップ数が増えるとperplexityの増加が予想 LAMBADAの精度は、ラウンドが短くても変 化しなかった 14
アブレーション:教師のsampling steps ⚫ 教師モデルのターゲットは2サンプリングステップ→変化させるとどうなるか実験 ⚫ 全体として、一度に2ステップ以上を蒸留すると、パフォーマンスが低下 ⚫ この結果は、4ステップまたは8ステップで生成されたターゲットの不確実性が高いため、 生徒にとってタスクが難しすぎることを示唆 15
アブレーション:sampler ⚫ Analytical sampler はより高品質なサンプルが得られることが知られているため、ancestral sampling(ここま で使用)とanalytical samplerを比較 ⚫ Sampler による差はほとんどない 16
アブレーション:EMA ⚫ 重みの指数移動平均(EMA)を使用すると、拡散モデルからのサンプルの品質が向上することが知られ ているため、教師モデルに使用する場合を実験 ⚫ オプティマイザのみのリセットは効果がない ⚫ 重みのEMAを教師に使用すると、パフォーマンスがわずかに向上する可能性 17
実験結果:スケーリング ⚫ SDTTを最大860Mのパラメータを持つ大規模な離散拡散モデルに適用 ⚫ 蒸留により、性能が改善していることから、大きなモデルにも 適用可能な手法 小型のモデルが、最大モデルよりも優れたperplexityを達成 18
実験結果:latency ⚫ SDTTのレイテンシを、KVcachingを使用する自己回帰モデルと比較 ⚫ 32ステップでサンプリングすると4倍、16ステップだと8倍の改善 19
まとめ・考察 まとめ • SDTTは、パフォーマンスを維持しながらデコードステップの数を削減 • 生徒モデルは、KVキャッシングを使用するARモデルよりも最大8倍高速 • SDTTはより大きなモデルにも適用可能 • 今後の研究では、基本言語モデルから大量の補完を生成するタスクでSDTTを評 価する予定 感想 • 高速化が実現されているが、理論的な説明が少ない • 現状では、評価指標が少ないため、今後の研究に期待 20