【DL輪読会】CONTINUAL LEARNING OF DIFFUSION MODELS WITH GENERATIVE DISTILLATION

3.3K Views

June 14, 24

スライド概要

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

関連スライド

各ページのテキスト
1.

CONTINUAL LEARNING OF DIFFUSION MODELS DEEP LEARNING JP [DL Papers] WITH GENERATIVE DISTILLATION Hiroshi Sekiguchi, Academic Consultant, Morikawa Lab http://deeplearning.jp/ 1

2.

書誌情報 • “Continual Learning of Diffusion Models with Generative Distillation”, Sergi Masip*, Pau Rodríguez†, Tinne Tuytelaars‡, and Gido M. van de Ven‡ * Master in Computer Vision (Barcelona), † Apple, ‡ KULeuven 3rd Conference on Lifelong Learning Agents (CoLLAs), 2024 https://arxiv.org/pdf/2311.14028 • 概要 – 画像生成diffusion modelにおける複数タスクの継続学習(Continual Learning) – 継続学習: 直前のタスクで学んだことを忘却せずに次のタスクで新しいことを追 加で学習すること→往々にして起こる壊滅的忘却を防止する方法がポイント – 直前のタスクを学習したmodelを教師にし、現在のタスクを学習するmodelを生徒 にしたteacher-studentのknowledge distillation(generative distillationと呼ぶ)を用い て、忘却を抑える • 動機 – diffusion model関連の研究を総括的にReviewしていたところ、 diffusion modelにお ける継続学習の方法に興味を持った 2

3.

アジェンダ • 背景 • 提案手法 • 評価 • まとめ • 感想 3

4.

背景: Diffusion Modelの概念 • Diffusion modelとは – 生成モデルの一手法 (他に、VAE、GANなど) – 高精細の画像生成、テキスト→画像生成、3D point cloud画像生成、音声生成などへ発展 – 長所:最適化がGANに比べて多様なデータの生成が可能、学習が二乗誤差最小化と単純 – 短所:生成に時間が掛かる • Diffusion modelの考え方 – 「ノイズ(潜在変数)から元データへの変換」を学習するのは難しい ⇒「元データからノイズへの変換」は容易く実現でき、その逆変換を学習すれば良い ノイズ 元データ 4

5.

背景: Diffusion Model (DDPM)の数式化 • Diffusion modelの数式化(Denoising Diffusion Probabilistic Models: DDPM)[1] – Forward process:ガウスノイズを追加 𝑥𝑡 = 1 − 𝛽𝑡 𝑥𝑡−1 + 𝛽𝑡 𝜖 𝜖: ガウスノイズ、𝛽𝑡 : 分散(小さな定数) → 𝑞 𝑥𝑡 𝑥𝑡−1 ≡ 𝒩(𝑥𝑡 ; 1 − 𝛽𝑡 𝑥𝑡−1 , 𝛽𝑡 𝐼) → 𝑞 𝑥𝑡 𝑥0 = 𝒩(𝑥𝑡 ; 𝛼𝑡 𝑥0 , 1 − 𝛼𝑡 𝐼) 𝛼𝑡 ≡ 1 − 𝛽𝑡 , 𝛼𝑡 ≡ ς𝑇𝑠=1 𝛼𝑠 – Reverse process:ノイズ除去を学習 元データ ノイズ • 𝑝𝜃 𝑥𝑡−1 𝑥𝑡 = 𝒩(𝑥𝑡−1 ; 𝜇𝜃 𝑥𝑡 , 𝑡 , 𝜎𝑡2 𝐼) , 𝜇𝜃 𝑥𝑡 , 𝑡 をneural network(係数θ)で学習、 𝜎𝑡2:定数 • VAEの概念で学習を定義: 𝑥0 ~𝒟(学習入力データ空間:元データ空間) 潜在変数𝑧 = 𝑥1:𝑇 として 目的関数:元データ𝑥0の負の対数尤度 − log 𝑝𝜃 (𝑥0)の変分下限の最小化でθを学習 𝑝 (𝑥 ) すなわち、− log 𝑝𝜃 𝑥0 ≤ − ‫𝑥(𝑞 ׬‬1:𝑇 |𝑥0 ) log 𝜃 0:𝑇 𝑑𝑥1:𝑇 ≡Loss 𝑞(𝑥1:𝑇 |𝑥0 ) 実際は、 neural networkを𝜇𝜃 𝑥𝑡 , 𝑡 の予測→デノイズするべきノイズ𝜖𝜃 𝑥𝑡 , 𝑡 の予測に変えると 𝐿𝑜𝑠𝑠 = 𝔼𝑡,𝑥0 ,𝜖 𝜖 − 𝜖𝜃 𝑥𝑡 , 𝑡 2 ,𝑥 = 𝑡 𝛼𝑡 𝑥0 + 1 − 𝛼𝑡 𝜖 となる 5

6.

背景: Diffusion Model(DDPM)の学習手順 • DDPMにおける学習手順:𝐿𝑜𝑠𝑠 = 𝔼𝑡,𝑥0,𝜖 𝜖 − 𝜖𝜃 𝑥𝑡 , 𝑡 2 ,𝑥 = 𝑡 𝛼𝑡 𝑥0 + 1 − 𝛼𝑡 𝜖 • DNNは各時刻で同じmodelを共通して学習 • DNN:self-attention付きU-Net • 問題点:高画質の生成にはT=1000が必要→生成が遅い、計算コストが大きい 時刻情報 𝑡~𝑈(1, 𝑇) 学習 データ 𝑥0 ~𝐷 モデルを更新 DNN (𝜃) ノイズ付データ 𝑥𝑡 = 𝛼𝑡 𝑥0 + 1 − 𝛼𝑡 𝜖 推定した ノイズ𝜖𝜃 二乗誤差 の𝑡, 𝑥0 , 𝜖に 渡る平均 を最小化 ガウス ノイズ 𝜖~𝒩(0,1) 6

7.

背景: Diffusion Model(DDPM)の生成手順 • 時間𝑇から𝑇 ステップのReverse processで𝑥0 計算: 𝑝𝜃 𝑥𝑡−1 𝑥𝑡 = 𝒩(𝑥𝑡−1 ; 𝜇𝜃 𝑥𝑡 , 𝑡 , 𝜎𝑡2 𝐼) を用いて生成の手順を下記に示す • DDPMの短所:生成にT=1000ステップ掛かる 𝑝𝜃 𝑥𝑡−1 𝑥𝑡 𝑥𝑇 ・・・ 𝑥𝑡−1 𝑥𝑡 𝑥𝑇 ~𝒩(0,1) ・・・ 生成データ 𝑥0 𝑥𝑡−1 = 𝜇𝜃 𝑥𝑡 , 𝑡 + 𝜎𝑡 𝑧 𝜇𝜃 𝑥𝑡 , 𝑡 = 1 1 − 𝛽𝑡 𝑥𝑡 − 𝛽𝑡 1 − 𝛼𝑡 𝜖𝜃 𝑥𝑡 , 𝑡 𝑥𝑡 と𝑡が与えられれば、NNから𝜖𝜃 𝑥𝑡 , 𝑡 が得られる 𝑥0 z~𝒩(0,1) 7

8.

背景: Diffusion Model(DDIM)による改良 • Diffusion Model(DDPM)の短所: – 生成時間が長い→与えられた計算リソース内で生成できる画像数が少ない→複 数タスクの継続学習に不向き→生成時間が速い解法が欲しい – 確率過程なので、生成を同じ𝑥𝑇 から始めても同じ生成結果𝑥0 は得られない→継続 学習の再現性が担保できない→決定的サンプリングが欲しい、でも、学習済 DDPMのModelは使いたい • 対応法:Denoising Diffusion Implicit Models(DDIM)[2]を使う – DDIMは、DDPMと同じ形の損失関数に対する過程→学習が同じ→DDPMで学習した Model(𝜖𝜃 )はDDIMの生成過程で使いまわし可能 – 具体的には、DDIMの生成過程は、DDPMの拡散過程を拡張した過程で𝜎𝑡 = 0に固 決定的サンプルが可能な過程 定したもので、決定的サンプリングを実現 – DDIMの生成過程は、時間ステップ<10程度で高精細 画像が得られる DDPMの拡散過程 拡散過程を拡張した過程 8

9.

背景: Continual Learning • 複数タスク(𝒯 = {𝒯1, , 𝒯2 , ⋯ , 𝒯𝐵 })の継続的な学習 • 問題点:破壊的忘却:以前のタスクで得た能力を現在のタスク学習で失う現象 • 先行研究の対応手法 ① アーキテクチャ・パラメータ分離手法:アーキテクチャやパラメターをタスク毎に個別に 学習する方法 ② 正則化手法:以前のタスク知識を現在のタスクに保存するために、以前のタスクの知識を 正則化項として損失関数に加える方法 ③ リハーサル手法:現在のタスクの学習データに以前のタスクの代表的なデータを用いて学 習を補完する( experience replay )方法 ①アーキテクチャ・パラメータ分離手法 アーキテ クチャ・ パラメー タ1 アーキテ クチャ・ パラメー タ2 ・・・ アーキテ クチャ・ パラメー タN ③リハーサル手法 一部をcopy ②正則化手法 タスク 𝒯𝑖−2 の学習 data Model 𝑖 − 2 タスク𝒯𝑖−1 の学習 data Model 𝑖 − 1 experience replay Experience buffer Loss=(現在タスクLoss)+λ(以前タスクの知識の正則項) タスク 𝒯𝑖 の学習 data Model 𝑖 9

10.

背景:Diffusion modelにおける継続学習(generative replay:GR) • Diffusion modelにおける継続学習の先行研究 – リハーサル手法が用いられている ①はmodel規模∝タスク数が欠点、②はタスク知識の正則化へのマッピングが難しい – 以前のタスクの代表的データそのものを引き継ぐのではなく、以前の生成モデル を引き継いで、そのモデルから以前のタスクのデータを生成して現在のタスクの 学習に使う方法(generative replay) – この方法でも壊滅的忘却が起こる[3] タスク𝒯𝑖−1 の学習 data Generative replay Model 𝑖 − 1 copy Model 𝑖 − 1 以前 の タス ク のdata 𝑥𝑡 Model 𝑖 タスク 𝒯𝑖 の学習 data Generative replayのLossreplay=𝔼𝑡,𝑥𝑡 ~𝜖𝜃෡ 𝑖−1 𝜖 − 𝜖𝜃𝑖 𝑥𝑡 , 𝑡 2 , 𝑥𝑡 = 𝛼𝑡 DDIM𝑁 𝜖𝜃෡𝑖−1 + 1 − 𝛼𝑡 𝜖, DDIM𝑁 𝜖𝜃෡𝑖−1 は𝜖𝜃෡𝑖−1 を使ってDDIMでN回で生成した画像(高精細画像に生成されている) 1 1 10 タスク𝒯𝑖 全体のLoss𝐺𝑅= (タスク𝒯𝑖 の学習 dataでの学習Loss)+(1− )(Generative replayのLossreplay)

11.

提案手法:generative distillation(GD)による Diffusion model継続学習) • Diffusion modelによる継続学習で、壊滅的忘却を低減する方法としてgenerative distillation を提案 • Teacher-Student distillation model – Teacher model: タスク𝒯𝑖−1 で学習済のノイズ予測のneural network 𝜖𝜃෡ 𝑖−1 – Student model: タスク𝒯𝑖 で学習対象であるノイズ予測のneural network 𝜖𝜃𝑖 – 壊滅的忘却の原因の一つに、DDIM𝑁 𝜖𝜃෡ 𝑖−1 でのreplay画像生成に不具合がある画像が𝜖𝜃𝑖 の学習で誤 差を生じるため。本方式であれば、後続の正しいreplay画像での学習で修正可能(generavtive Teacher-Student Distillation replayでは修正不可能) タスク𝒯𝑖−1 の学習 Model 𝑖 − 1 data copy Generative Distillation 𝑥𝑡 Model 𝑖 − 1 Model 𝑖 タスク 𝒯 の学習 data 以前 の タス ク のdata 𝑖 Generative DistillationのLossdistil=𝔼𝑡,𝑥𝑡 ~𝜖𝜃෡ 𝑖−1 𝜖𝜃෡𝑖−1 𝑥𝑡 , 𝑡 − 𝜖𝜃𝑖 𝑥𝑡 , 𝑡 2 , 𝑥𝑡 = 𝛼𝑡 DDIM𝑁 𝜖𝜃෡𝑖−1 + 1 − 𝛼𝑡 𝜖, DDIM𝑁 𝜖𝜃෡𝑖−1 は𝜖𝜃෡𝑖−1 を使ってDDIMでN回で生成した画像(高精細画像に生成されている) 1 1 タスク𝒯𝑖 全体のLoss𝐺𝐷= (タスク𝒯𝑖 の学習 dataでの学習Loss)+(1− )λ(Generative distillationのLossdistil) 11 𝑖 𝑖

12.

評価方法 • Generative distillation(GD)の性能をベースラインと比 較評価する • ベースライン:generative replay(GR)学習, joint学習 ( 全タスク同時学習)、finetuning(タスクを順次 finetuning) • Dataset: Fashion-MNIST、CIFAR-10(10個のクラス) • タスクを5つのタスクに分け、5つのタスクを順 番に継続学習。2つのdatasetのそれぞれで、2class 毎の生成を1タスクとする • 評価尺度:生成画像の定性的画質、定量的品質 (Fréchet Inception distance:FIDとKullback-Leibler divergence:KLD)、並行して学習した画像認識分類 器による分類正解率(ACC) • GDとGRで、replay数とタスク新規学習数は同数 Fashion-MNIST CIFAR-10 12

13.

評価:Generative replayの客観的画質評価 • Generative Replay(GR)では、第二タスク以降に、壊滅的破壊が見られる • GRは、壊滅的忘却が生じている(画像がボケてしまう) • これは、GRでは、デノイジング能力が悪くなることに起因する Fashion -MNIST CIFAR-10 13

14.

評価結果: Generative distillationの定性的画質評価 • 全タスク(5タスク)の学習 を終了した後の生成画像の 定性的画質 • Generative distillation(GD)は、 GRに比較して、壊滅的忘却 が著しく改善している • GDは、過去に学習した知識 を維持可能 14

15.

評価結果: Generative distillationの定量的画質評価 • GDはGRと比べて定量的な画質指標で良好 • Jointが高位の性能目標で、Finetuningが低 位の性能目標 • GDの性能は、2つのdatasetの2つの性能 尺度で、いずれも、Joint>GD> Finetuning(=naïve) • GRは、CIFAR-10では、Finetuningよりも悪 くなる(CIFAR-10は画像生成として難し い画像) • 並行して学習した画像認識分類器で該当 クラスの正解率は、いずれも、 Joint>GD>GR>Finetuningが成立 15

16.

評価:DDIMの生成step数と生成画質 CIFAR-10 Fashion-MNIST • DDIMでの生成step数と生成画像の定性的画質の関係: Joint(全クラス同時学習)で作成 したmodelからDDIMのstep数2,10,100で生成した画像の品質→Fashion-MNISTでは2step、 CIFAR-10では10stepで許容画質になる→P.13-15の定量的画質評価では、GRとGDともに、 Fashion-MNISTで生成step数=2、CIFAR-10で生成step数=10で学習したものを使用 • DDIMで壊滅的忘却を阻止する生成step数:Fashion-MNISTでは、Generative Distillationは 2stepで許容可能だが、generative replayは100でもゴミあり→GRは壊滅的忘却が抑制不可 Fashion-MNIST 16

17.

評価:Generative Distillationの種類を評価 • GDの入力 𝑥𝑡 を幾つか変えてみた。 – LwF(learning w/o forgetting):右図 の𝑥𝑡 をタスク𝒯𝑖 のdataとする – Gaussian:右図の𝑥𝑡 を𝑥𝑡 ~𝒩(0,1)とす る →DDIMで0stepと同等 Teacher-Student Distillation 𝑥𝑡 • 驚くべきことに、LwFもGaussian もGRよりも良い • しかし、いずれも、GDよりは悪 い 17

18.

評価:GR/GD/Joint/Finetuning/LwF/Gaussianの定量比較 • 総合比較はGDが優秀 Good! Good! 18

19.

本手法の限界と課題 • DDIMで生成の計算量を下げたと言えども、GR/GDともに計算量が多く、ク ラス数が小ない小規模なdatasetに適用することしかできなかった • 大規模なdatasetに対処するためには、DDIMよりも生成効率の良いサンプル 生成法が必要である 19

20.

まとめ • Generative replayを用いたdiffusion modelの継続学習は、破滅的忘却が起 こったが、Generative distillationでは、破滅的忘却を抑制することができ た 20

21.

感想 • Generative distillationは、FinetuningやGenerative replayよりも壊滅的忘却 が抑制されていることは興味深い • Datasetが小規模な物しか示されず残念 • DDIMを凌駕するサンプリング効率のサンプリング手法が今後の発展には 急務 • Diffusion modelは、Continuous normalizing flow, fow matching, Conditional flow matchingと展開が速いが、これらのアルゴリズムで、継続学習の概 念はどうなるのか気になる。例えば – Flow系のアルゴリズムで、学習に掛かる計算量と画像生成に掛かる計算量は、そ れぞれ、DDPMとDDIMと比較して減少できるか? など 21

22.

参考文献 [1] Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising diffusion probabilistic models." Advances in neural information processing systems 33 (2020): 68406851. [2] Song, Jiaming, Chenlin Meng, and Stefano Ermon. "Denoising diffusion implicit models." arXiv preprint arXiv:2010.02502 (2020). [3] Smith, James Seale, et al. "Continual diffusion: Continual customization of textto-image diffusion with c-lora." arXiv preprint arXiv:2304.06027 (2023). 22

23.

END 23