26.6K Views
June 25, 24
スライド概要
YouTubeはこちら→https://youtu.be/hbgUG8Ujlpg
DL輪読会資料
拡散モデルの数理 ① DDPMを理解する 谷口尚平 2024年3月19日
発表概要 目標 • 拡散モデルの数理的な背景をきちんと理解する ‣ 何を最大化して学習しているか ‣ なぜこんなに上手くいくのか • 1回では難しいので、3回くらいに分けて説明する
発表概要 アジェンダ 1. DDPMを理解する • VAEとの関係 • ロスの導出 2. スコアマッチングを理解する 3. スコアベース拡散モデルを理解する
発表概要 アジェンダ 1. DDPMを理解する • VAEとの関係 • ロスの導出 2. スコアマッチングを理解する 3. スコアベース拡散モデルを理解する
拡散モデルの歴史 VAE Implicit Score Matching [Kingma, et al., 2013] [Hyvärinen, et al., 2005] 潜在変数モデル DPM スコアマッチング Denoising Score Matching [Sohl-Dickstein, et al., 2015] [Vincent, et al., 2011] DDPM NCSN [Ho, et al., 2020] [Song, et al., 2020] 統一的理解 DDPM++, NCSN++ [Song, et al., 2021]
Variational autoencoder [Kingma, et al., 2013] 潜在変数モデル • ノイズ z を入力として、データ x を出力するモデル • fθ が通常ニューラルネットで、これを学習したい 2 𝒩 𝒩 • 分散パラメータ σ も学習対象にすることが多い pθ (x) = p (z) pθ (x ∣ z) dz ∫ p (z) = (z; 0, I) 2 pθ (x ∣ z) = (x; fθ (z), σ I)
Variational autoencoder 目的関数 • データ分布とモデルの分布のKL最小化 DKL (pdata (x) ∥pθ (x)) • モデルが生成するデータを学習データに近づける 𝒩 𝒩 • 対数尤度 log pθ (x) を最大化と同値 pθ (x) = p (z) pθ (x ∣ z) dz ∫ p (z) = (z; 0, I) 2 pθ (x ∣ z) = (x; fθ (z), σ I)
Variational autoencoder Evidence lower bound, ELBO • 対数尤度は期待値を含むため計算できない log p (x) = log p(z) [pθ (x ∣ z)] • 変分下界の最大化で代替 log p (x) ≥ qϕ(z ∣ x) [ log 𝒩 𝒩 𝒩 𝔼 𝔼 = ℒ (x; θ, ϕ) qϕ (z ∣ x) ] pθ (x, z) pθ (x) = p (z) pθ (x ∣ z) dz ∫ p (z) = (z; 0, I) 2 pθ (x ∣ z) = (x; fθ (z), σ I) qϕ (z ∣ x) = (z; μϕ (x), Σϕ (x))
Variational autoencoder エンコーダ qϕ の学習 ℒ (x, θ, ϕ) = log pθ (x) − DKL (qϕ (z ∣ x) ∥pθ (z ∣ x)) • 変分下界 ℒ の最大化はエンコーダ qϕ から見ると DKL (qϕ (z ∣ x) ∥pθ (z ∣ x)) の最小化 𝒩 𝒩 𝒩 pθ (x) = p (z) pθ (x ∣ z) dz ∫ • エンコーダを学習することで下界がタイトになる p (z) = (z; 0, I) 2 pθ (x ∣ z) = (x; fθ (z), σ I) qϕ (z ∣ x) = (z; μϕ (x), Σϕ (x))
Variational autoencoder 課題 • VAEはエンコーダとデコーダを同時に学習する必要があり、うまく学習する のが難しい Posterior collapse: エンコーダがデータの情報を無視してしまう現象 • エンコーダの表現力が分布の選び方によって制限される ‣ 通常は正規分布にするが、単峰な分布しか近似できず、表現力が限られる
DDPM
[Ho, et al., 2020]
<latexit sha1_base64="l4LvSgM7PR7I/kkuy5soikK4gpU=">AAAEoXictVLditNAFE7XqGv92a5eejOYLexKLU0VFKRQ9EYvhCrb3YUklOlk2g6dnzBzYrcb8zK+lU/gazhJK6atuiB4YODM+T/n+8YJZwY6nW+1vRvuzVu39+/U7967/+CgcfjwzKhUEzokiit9McaGcibpEBhwepFoisWY0/Px/G3hP/9MtWFKnsIyoZHAU8kmjGCwplHjeygwzAjThNM4Kz/jSXaZj05zFHIlp5pNZ4C1VgsUkliB2TX/oQLYCpe/4rJwZhJM6NPMJyLPt9IM0SwBA0tOUaVGBs/8/J8mWVRH6eSjhtdpd0pBu4q/VjxnLYPR4d7XMFYkFVQC4diYwO8kEGVYA7P183qYGmr3meMpDawqsaAmykpEctS0lhhNlLZPAiqt1YwMC2OWYmwjiynNtq8w/s4XpDB5FWVMJilQSVaNJilHoFABL4qZpgT40irYntTOisgMa0zAkqC+0QbY/MquIfCcYssbsBH1UNIFUUJgGVePGfhR1qyj1YETXAaH/SqAnp836/lGftUfdNcFiqbBT8L2jouQdvE9iVAoVUyDWONFa5XVYlJSjezEPT+BlmCSiVQgw65or2vBaE0Y5z1e4D/VeBmhstwJyo5C0YeZ53vdo/z19lhVjly71+K6xRb/ZbO/rbLCS8HMwmVZ7W9zeFc567b95+3uxxde/82a3/vOY+eJc+z4zkun77xzBs7QIbUPNVP7Ustdz33vDtxPq9C92jrnkbMhbvAD81mObw==</latexit>
!
• エンコーダとしてデータに徐々にノイズが
! · · · ! xt
p✓ (xt 1 |xt )
! xt 1
q(xt |xt 1 )
<latexit sha1_base64="eAZ87UuTmAQoJ4u19RGH5tA+bCI=">AAACC3icbVC7TgJBFJ31ifhatbSZQEywkOyiiZQkNpaYyCMBspkdZmHC7MOZu0ay0tv4KzYWGmPrD9j5N87CFgieZJIz59ybe+9xI8EVWNaPsbK6tr6xmdvKb+/s7u2bB4dNFcaSsgYNRSjbLlFM8IA1gINg7Ugy4ruCtdzRVeq37plUPAxuYRyxnk8GAfc4JaAlxyzclbo+gaHrJQ8TB/AjnvsmcGZPTh2zaJWtKfAysTNSRBnqjvnd7Yc09lkAVBClOrYVQS8hEjgVbJLvxopFhI7IgHU0DYjPVC+Z3jLBJ1rpYy+U+gWAp+p8R0J8pca+qyvTRdWil4r/eZ0YvGov4UEUAwvobJAXCwwhToPBfS4ZBTHWhFDJ9a6YDokkFHR8eR2CvXjyMmlWyvZ5uXJzUaxVszhy6BgVUAnZ6BLV0DWqowai6Am9oDf0bjwbr8aH8TkrXTGyniP0B8bXL+1hmu8=</latexit>
加わる過程を考える
‣ zt の次元はすべてデータ x と同一
• モデル pθ は、逆にノイズを少しずつ除去
𝒩
𝒩
していく形になる
! · · · ! x0
sha1_base64="7yFrn0YPyuP5dVIvc7Tl2zcbS/g=">AAAB+HicbVBNSwMxEJ2tX7V+dNWjl2ARPJXdKuix6MVjBfsB7VKyaXYbmk2WJKvU0l/ixYMiXv0p3vw3pu0etPXBwOO9GWbmhSln2njet1NYW9/Y3Cpul3Z29/bL7sFhS8tMEdokkkvVCbGmnAnaNMxw2kkVxUnIaTsc3cz89gNVmklxb8YpDRIcCxYxgo2V+m65x6WIFYuHBislH/tuxat6c6BV4uekAjkafferN5AkS6gwhGOtu76XmmCClWGE02mpl2maYjLCMe1aKnBCdTCZHz5Fp1YZoEgqW8Kgufp7YoITrcdJaDsTbIZ62ZuJ/3ndzERXwYSJNDNUkMWiKOPISDRLAQ2YosTwsSWYKGZvRWSIFSbGZlWyIfjLL6+SVq3qn1drdxeV+nUeRxGO4QTOwIdLqMMtNKAJBDJ4hld4c56cF+fd+Vi0Fpx85gj+wPn8AXOGk5o=</latexit>
xT
<latexit sha1_base64="XVzP503G8Ma8Lkwk3KKGZcZJbZ0=">AAACEnicbVC7SgNBFJ2Nrxhfq5Y2g0FICsNuFEwZsLGMYB6QLMvsZDYZMvtg5q4Y1nyDjb9iY6GIrZWdf+Mk2SImHrhwOOde7r3HiwVXYFk/Rm5tfWNzK79d2Nnd2z8wD49aKkokZU0aiUh2PKKY4CFrAgfBOrFkJPAEa3uj66nfvmdS8Si8g3HMnIAMQu5zSkBLrlmO3R4MGZBSLyAw9Pz0YeKmcG5P8CNekKDsmkWrYs2AV4mdkSLK0HDN714/oknAQqCCKNW1rRiclEjgVLBJoZcoFhM6IgPW1TQkAVNOOntpgs+00sd+JHWFgGfq4kRKAqXGgac7p0eqZW8q/ud1E/BrTsrDOAEW0vkiPxEYIjzNB/e5ZBTEWBNCJde3YjokklDQKRZ0CPbyy6ukVa3YF5Xq7WWxXsviyKMTdIpKyEZXqI5uUAM1EUVP6AW9oXfj2Xg1PozPeWvOyGaO0R8YX7+bCp4F</latexit>
<latexit
• 階層的な z をもつVAE
pθ (x, z1:T) = p (zT)
q (z1:T ∣ x) =
T
∏
t=1
q (zt ∣ zt−1) =
pθ (zt−1 ∣ zt) =
T
∏
t=1
pθ (zt−1 ∣ zt)
q (zt ∣ zt−1)
z
;
1
−
β
z
,
β
I
t
t
t−1
t
(
)
(zt−1; μθ (zt, t), Σθ (zt, t))
ϵ∼ DDPM αt = 拡散過程 t (ϵ; 0, I) ∏ i=1 1 − βi , σt = 2 1 − αt エンコーダが正規分布なので、x から各時刻の zt へのジャンプも解析的に求まる zt = βt ϵ (1 − βt) (1 − βt−1) zt−2 + 𝒩 = 1 − βt zt−1 + 𝒩 ⋮ = αtx + σt ϵ ⇒ q (zt ∣ x) = 1 − 1 − β 1 − β ϵ ( ) ( ) t t−1 ( ) (zt; αtx, σt I)
ϵ∼ DDPM αt = ELBO t (ϵ; 0, I) ∏ i=1 1 − βi , σt = 2 1 − αt 𝔼 pθ (x, z1:T) ℒ (x; θ, ϕ) = q(z1:T ∣ x) log [ q (z1:T ∣ x) ] = T ∑ 𝔼 t=1 =− q(zt, zt−1 ∣ x) T ∑ 𝒩 𝔼 t=1 [ log pθ (zt−1 ∣ zt) q (zt−1 ∣ x, zt) ] q(zt ∣ x) [DKL (q (zt−1 ∣ x, zt) ∥pθ (zt−1 ∣ zt))]
ϵ∼ DDPM αt = ELBO t∼ ℒ (x; θ) =− t (ϵ; 0, I), zt ∼ q (zt ∣ x) ∏ i=1 1 − βi , σt = 2 1 − αt (t; {1,…, T}) T D q z ∣ x, z ∥p z ∣ z ( ) ( ) KL t−1 t θ t−1 t [ ( )] ∑ t=1 D q z ∣ x, z ∥p z ∣ z ( ) ( KL t−1 t θ t−1 t))] [ ( 𝔼 =−T⋅ • 時刻 t をランダムに選んで、その時刻でのKL divergenceを計算することで、 𝒩 𝒰 𝔼 ELBOの不偏推定量を得られる
ϵ∼ DDPM αt = ELBO t (ϵ; 0, I), zt ∼ q (zt ∣ x) ∏ t∼ i=1 1 − βi , σt = (t; {1,…, T}) デコーダを以下の構造にすると、KLが加えたノイズの予測誤差の形になる μθ (xt, t) = 1 − βt ( 1 xt − ϵθ (xt, t) ) 1 − αt βt ⇒ DKL (q (zt−1 ∣ x, zt) ∥pθ (zt−1 ∣ zt)) 𝒩 𝒰 = 1 2 (1 − βt) (1 − αt) ϵ − ϵθ (xt, t) 2 2 1 − αt
ϵ∼ DDPM αt = シンプルなロス関数 t∼ t (ϵ; 0, I), zt ∼ q (zt ∣ x) ∏ i=1 1 − βi , σt = (t; {1,…, T}) • 実験的には、ELBOの係数の (1 − βt) (1 − αt) を無視した方が安定する ℒsimple (x; θ) = [ ϵ − ϵθ (zt, t) 𝒩 𝒩 𝒰 𝔼 • この場合、シンプルにノイズ予測の二乗誤差の形になる 2 ] 2 1 − αt
DDPM • CelebA-HQなどの高い解像度の 画像生成も安定して学習できる
DDPM まとめ • 階層的なVAEのエンコーダとして、データ x に逐次的にノイズを加える過程を 考えると、そのELBOは加えたノイズを予測するモデルの二乗誤差と対応する • エンコーダがパラメータをもたないので、VAEよりも安定して学習ができる • ELBOを簡略化したロスを使うことで、高解像度な画像生成なども実現
拡散モデルの数理 ② スコアマッチングを理解する 谷口尚平 2024年4月30日
発表概要 アジェンダ 1. DDPMを理解する • VAEとの関係 • ロスの導出 2. スコアマッチングを理解する 3. スコアベース拡散モデルを理解する
拡散モデルの歴史 再掲 VAE Implicit Score Matching [Kingma, et al., 2013] [Hyvärinen, et al., 2005] 潜在変数モデル DPM スコアマッチング Denoising Score Matching [Sohl-Dickstein, et al., 2015] [Vincent, et al., 2011] DDPM NCSN [Ho, et al., 2020] [Song, et al., 2020] 統一的理解 DDPM++, NCSN++ [Song, et al., 2021]
Score-based Model スコアベースモデル • スコア関数 ∇x log p (x) をニューラルネット sθ̂ (x) でモデル化する • x をどの方向に動かせば、尤もらしい( = p (x) が高い)サンプルになるかを 表している
Langevin Monte-Carlo ランジュバン・モンテカルロ法 • スコア関数 ∇x log p (x)を使って p (x) からサンプリングする方法 x ← x′ ∼ (x), x′ ; x + η ∇ log p x (  𝒩  • 実際には、 ∇x log p (x) は推定器 sθ̂ (x) で代替 2ηI)
x ∼ pdata (x) Score Matching スコアマッチング 1 JSM (θ) = [ 2 ∇x log pdata (x) − sθ̂ (x) ] 2 • 推定器 sθ̂ (x) を真のスコア ∇x log p (x) に近づけるように学習する 𝔼 • しかし、これは ∇x log p (x) がわからないので、実際には計算できない
x ∼ pdata (x) Implicit Score Matching [Hyvärinen, et al., 2005] JISM (θ) = 1 [2 sθ̂ (x) • これは計算できる 𝔼 • JSM の最小化は JISM の最小化と等価 2 + tr ( ∇x sθ̂ (x)) ]
𝒪 x ∼ pdata (x) Implicit Score Matching 暗黙的スコアマッチング JISM (θ) = 1 [2 課題 • tr ( ∇x sθ̂ (x)) の計算が厳しい d 𝔼 • x ∈ ℝ のとき (d ) かかる 2 sθ̂ (x) 2 + tr ( ∇x sθ̂ (x)) ]
Denoising Score Matching [Vincent, et al., 2011] • x の代わりに x にノイズを乗せた xt のスコア関数を推定することを考える pdata(x) [q0t (xt ∣ x)] xt ∼ qt (xt) = q0t (xt ∣ x) = 𝒩 𝔼 𝔼 1 J˜SM (θ; t) = 2 [ 2 (xt; αtx, σt I) ∇xt log qt (xt) − sθ̂ (xt) 2 ]
Denoising Score Matching デノイジングスコアマッチング J˜DSM (θ; t) = 1 [2 xt ∼ qt (xt) = pdata(x) [q0t (xt ∣ x)] 2 q0t (xt ∣ x) = (xt; αtx, σt I) ϵ ∼ (ϵ; 0, I) ∇xt log q0t (xt ∣ x) − sθ̂ (xt) 2 ] • ターゲットを ∇xt log qt (xt) から条件付きスコア ∇xt log qt (xt ∣ x) に変更 𝔼 𝒩 𝒩 𝔼 ˜ ˜ J J の最小化は の最小化と等価 • SM DSM
Denoising Score Matching デノイジングスコアマッチング 𝔼 𝒩 𝒩 𝔼 𝔼 𝔼 J˜DSM (θ; t) = 1 [2 = 1 2 = 1 2 [ 2σt xt ∼ qt (xt) = pdata(x) [q0t (xt ∣ x)] 2 q0t (xt ∣ x) = (xt; αtx, σt I) ϵ ∼ (ϵ; 0, I) ϵθ̂ (xt) = − σt ⋅ sθ̂ (xt) ∇xt log q0t (xt ∣ x) − sθ̂ (xt) 2 1 − ϵ − sθ̂ (xt) σt ϵ − ϵθ̂ (xt) 2 ] 2 ]
Denoising Score Matching デノイジングスコアマッチング J˜DSM (θ; t) = 1 2 2σ [ t xt ∼ qt (xt) = pdata(x) [q0t (xt ∣ x)] 2 q0t (xt ∣ x) = (xt; αtx, σt I) ϵ ∼ (ϵ; 0, I) ϵθ̂ (xt) = − σt ⋅ sθ̂ (xt) ϵ − ϵθ̂ (xt) 2 ] • データにノイズを加えた分布 qt のスコアマッチングは、加えたノイズを予測 した際の二乗誤差の最小化によって行うことができる • αt ≈ 1, σt ≈ 0とすれば、qt ≈ pdataになるので、推定したスコア関数を使って 𝔼 𝒩 𝒩 𝔼 ランジュバン・モンテカルロ法でデータをサンプリングできる
Denoising Score Matching デノイジングスコアマッチング 課題 • ランジュバン・モンテカルロ法は、多峰な分布からのサンプリングが難しい • 別の峰に移動することが稀にしか起こらず、収束するのに時間がかかる
xt ∼ qt (xt) = pdata(x) [q0t (xt ∣ x)] 2 q0t (xt ∣ x) = (xt; αtx, σt I) ϵ ∼ (ϵ; 0, I) ϵθ̂ (xt) = − σt ⋅ sθ̂ (xt) 2 t−1 2 αt = 1, σt = γ σ1 , γ > 1 NCSN [Song, et al., 2020] • DSMをいろんなノイズレベル t で学習させる 𝔼 JDSM (θ) = ∑ 2T [ T t=1 T 2 σt 1 = ∑ 2T [ t=1 ∇xt log q0t (xt ∣ x) − sθ̂ (xt, t) ϵ − ϵθ̂ (xt, t) 𝔼 𝒩 𝒩 𝔼 • ノイズ予測器 ϵθ̂ は、t も入力にとる ] 2 2 ]
NCSN Annealed LMC • はじめは高いノイズレベルの分布 qT からのサンプリングからスタートして、 徐々にノイズレベルを下げていく  𝒩  for t = T, …,1 do t−1 ηt = η ⋅ γ for k = 1,…, K do xt ← x′ ∼ xt+1 ← xt ̂ x′ ; x + η ⋅ s x , t , ( ) t t θ t ( 2ηt I)
NCSN Annealed LMC • 通常のLMCは、各峰の密度比をうまく捉えられていないが、annealingをする と改善する
NCSN • 徐々にノイズが除去されていく ような形でサンプリングされる
拡散モデルの数理 ③ スコアベース拡散モデルを理解する 谷口尚平 2024年4月30日
発表概要 アジェンダ 1. DDPMを理解する • VAEとの関係 • ロスの導出 2. スコアマッチングを理解する 3. スコアベース拡散モデルを理解する
Score SDE [Song, et al., 2021] • DDPMやNCSNは、確率微分方程式を用いることで統一的に定式化できる • どちらも T → ∞ の極限を考えると、ノイズレベルが連続的に変化する過程と して記述できる
確率微分方程式 Stochastic differential equation, SDE 拡散過程 dxt = f (xt, t) dt + g (t) dw • x0 ∼ pdata (x0)として、このSDEで生成される xt がしたがう分布を qt とする • t = 1 で x1 は完全なノイズになる • 直感的には、時刻 t が進むにつれて、xt にノイズが加わっていくイメージ
確率微分方程式 Stochastic differential equation, SDE 逆拡散過程 dxt = [f (xt, t) − g (t) ∇xt log qt (xt)] dt + g (t) dw 2 • x1 ∼ p1 (x1)として、このSDEで生成される xt がしたがう分布を pt とする • p1 = q1 なら、任意の t で pt = qt • t = 1 から時刻を巻き戻すことで、ノイズからデータが生成される
確率微分方程式 Stochastic differential equation, SDE 逆拡散過程 dxt = [f (xt, t) − g (t) ∇xt log qt (xt)] dt + g (t) dw 2 • 各時刻 t におけるスコア関数を推定できれば、この逆拡散過程をシミュレート して、データを生成できる t ∈ (0,1) でスコアマッチングすればよい!
xt ∼ qt (xt) = t ∼ (t; 0,1) Score Matching JDSM (θ) = λt [2 ∇xt log q0t (xt ∣ x) − sθ̂ (xt, t) pdata(x) [q0t (xt ∣ x)] 2 ] • NCSNのときと同様に、時刻 t におけるスコア関数は、デノイジングスコア マッチングで推定できる • 時刻 t は一様ランダムにサンプルする 𝔼 𝒰 𝔼 • 条件付き分布 q0t は、SDEの形(関数 f, g の選び方)によって変わる
x0 ∼ q0 (x0) = pdata (x0) VE SDE g (t) = Variance Exploding SDE 拡散過程 • f (xt, t) = 0 の場合 dxt = g (t) dw • このとき、条件付き分布がNCSNの場合と同じ形になる 𝒩 q0t (xt ∣ x0) = 2 (xt; x0, σt I) 2 dσt dt
x0 ∼ q0 (x0) = pdata (x0) t 1 αt = exp − βsds ( 2 ∫0 ) VP SDE Variance Preserving SDE σt = 拡散過程 1 dxt = − βt ⋅ xtdt + 2 • f (xt, t) = βt ⋅ xt, g (t) = βt dw βt の場合 • このとき、条件付き分布がDDPMの場合と同じ形になる 𝒩 q0t (xt ∣ x0) = 2 (xt; αtx0, σt I) 2 1 − αt
xt ∼ qt (xt) = pdata(x) [q0t (xt ∣ x)] ϵ ∼ (ϵ; 0, I), t ∼ (t; 0,1) ϵθ̂ (xt) = − σt ⋅ sθ̂ (xt) Score Matching JDSM (θ) = = λt [2 λt 2 2σ [ t ∇xt log q0t (xt ∣ x) − sθ̂ (xt, t) ϵ − ϵθ̂ (xt, t) 2 2 ] ] 𝔼 • VE・VPいずれの場合も、ノイズ予測の形でスコアマッチングができる 𝒰 𝔼 𝒩 𝔼 2 • λt = σt とすることが多い
サンプリング dxt = [f (xt, t) − g (t) sθ̂ (xt, t)] dt + g (t) dw 2 • 推定したスコア関数 sθ̂ を使って、逆拡散過程のSDEをシミュレートすれば、 データのサンプリングができる • SDEのシミュレーションは、ソルバーを使って数値計算できる
Euler–Maruyama method オイラー・丸山法 dxt = [f (xt, t) − g (t) sθ̂ (xt, t)] dt + g (t) dw 2 離散化 xt+Δt ∼ ̂ (t) (t) x ; x + f x , t − g s x , t Δt, g ( ) ( ) t+Δt t t θ t ( [ ] 2 • 最も基本的なSDEのソルバー 𝒩 • Δt を十分小さくすれば、離散化誤差を抑えられる 2 Δt )
Probability flow ODE 確率フローODE • SDEには、各時刻の確率分布 qt が同一になるような常微分方程式(ODE)が 存在することが知られている dxt = f (xt, t) dt + g (t) dw ODE化 SDE化 1 2 dxt = f (xt, t) − g (t) ∇xt log qt (xt) dt [ ] 2
Probability flow ODE ODEによるサンプリング SDEの代わりにODEをシミュレートすることでサンプリングができる オイラー法 1 2 dxt = f (xt, t) − g (t) sθ̂ (xt, t) dt [ ] 2 1 2 xt+Δt = xt + f (xt, t) − g (t) sθ̂ (xt, t) Δt [ ] 2 • SDEサンプラーよりも少ないステップ数で早く収束することが多い
1 2 hθ (xt, t) = f (xt, t) − g (t) sθ̂ (xt, t) 2 Probability flow ODE 尤度評価 1 2 dxt = f (xt, t) − g (t) sθ̂ (xt, t) dt [ ] 2 • ODEを使うと、モデルの尤度を評価することができる ODE log p0̂ (x0; θ) = log p1 (x1) + ∫ 1 0 tr ( ∇xt hθ (xt, t)) dt ODE は x1 ∼ p1 (x1) として、ODEによって生成される xt の分布 • pt̂
Score SDE Model RealNVP (Dinh et al., 2016) iResNet (Behrmann et al., 2019) Glow (Kingma & Dhariwal, 2018) MintNet (Song et al., 2019) Residual Flow (Chen et al., 2019) FFJORD (Grathwohl et al., 2018) Flow++ (Ho et al., 2019) DDPM (L) (Ho et al., 2020) DDPM (Lsimple ) (Ho et al., 2020) DDPM DDPM cont. (VP) DDPM cont. (sub-VP) DDPM++ cont. (VP) DDPM++ cont. (sub-VP) DDPM++ cont. (deep, VP) DDPM++ cont. (deep, sub-VP) NLL Test Ó FID Ó 3.49 3.45 3.35 3.32 3.28 3.40 3.29 § 3.70* § 3.75* 46.37 13.51 3.17 3.28 3.21 3.05 3.16 3.02 3.13 2.99 3.37 3.69 3.56 3.93 3.16 3.08 2.92 Model FIDÓ ISÒ 14.73 2.42 9.22 10.14 StyleGAN2-ADA (Karras et al., 2020) NCSN (Song & Ermon, 2019) NCSNv2 (Song & Ermon, 2020) DDPM (Ho et al., 2020) 2.92 25.32 10.87 3.17 9.83 8.87 ˘ .12 8.40 ˘ .07 9.46 ˘ .11 DDPM++ DDPM++ cont. (VP) DDPM++ cont. (sub-VP) DDPM++ cont. (deep, VP) DDPM++ cont. (deep, sub-VP) NCSN++ NCSN++ cont. (VE) NCSN++ cont. (deep, VE) 2.78 2.55 2.61 2.41 2.41 2.45 2.38 2.20 9.64 9.58 9.56 9.68 9.57 9.73 9.83 9.89 Conditional BigGAN (Brock et al., 2018) StyleGAN2-ADA (Karras et al., 2020) Unconditional
Score SDE まとめ • SDEを用いて定式化すると、DDPMとNCSNはSDEの形が異なるだけで、両者 は統一的に記述できる ‣ ロス関数も全く同じになる ‣ 両者を連続時間SDEとして記述したモデルをDDPM++, NCSN++という • ODEとの対応を用いることで、サンプリングを高速化したり、尤度を評価し たりできる
ScoreFlow [Song, et al., 2021] • DDPMのロスは、ELBO(=対数尤度の下界)と対応していた • DDPM++やNCSN++で使われるスコアマッチングロスも、尤度と対応する
x ∼ pdata (x) xt ∼ qt (xt) = [q0t (xt ∣ x)] ϵ ∼ (ϵ; 0, I), t ∼ (t; 0,1) ϵθ̂ (xt) = − σt ⋅ sθ̂ (xt) ScoreFlow [Song, et al., 2021] JDSM (θ) = λt 2 2σ [ t ϵ − ϵθ̂ (xt, t) 2 ] 2 • λt = g (t) のとき、JDSM は対数尤度の期待値の下界になる 𝒰 𝔼 𝒩 𝔼 𝔼 SDE log p x; θ ≥ − J θ + C ( ) ( ) DSM 0 [ ]
ScoreFlow [Song, et al., 2021] Model SDE Baseline Baseline + LW Baseline + LW + IS Deep Deep + LW Deep + LW + IS Baseline Baseline + LW Baseline + LW + IS Deep Deep + LW Deep + LW + IS VP VP VP VP VP VP subVP subVP subVP subVP subVP subVP CIFAR-10 Uni. deq. Var. deq. NLLÓ BoundÓ NLLÓ BoundÓ 3.16 3.28 3.04 3.14 3.06 3.18 2.94 3.03 2.95 3.08 2.83 2.94 3.13 3.25 3.01 3.10 3.06 3.17 2.93 3.02 2.93 3.06 2.80 2.92 2.99 3.09 2.88 2.98 2.97 3.07 2.86 2.96 2.94 3.05 2.84 2.94 2.96 3.06 2.85 2.95 2.95 3.05 2.85 2.94 2.90 3.02 2.81 2.90 FIDÓ 3.98 5.18 6.03 3.09 7.88 5.34 3.20 7.33 5.58 2.86 6.57 5.40 ImageNet 32ˆ32 Uni. deq. Var. deq. NLLÓ BoundÓ NLLÓ BoundÓ 3.90 3.96 3.84 3.91 3.91 3.96 3.86 3.92 3.86 3.92 3.80 3.88 3.89 3.95 3.84 3.90 3.91 3.96 3.86 3.92 3.85 3.92 3.79 3.88 3.87 3.92 3.82 3.88 3.87 3.92 3.82 3.88 3.84 3.91 3.79 3.87 3.86 3.91 3.81 3.87 3.88 3.93 3.83 3.88 3.82 3.90 3.76 3.86 FIDÓ 8.34 17.75 11.15 8.40 17.73 11.20 8.71 12.99 10.57 8.87 16.55 10.18
まとめ • 拡散モデルは、階層VAEとスコアベースモデルという2つの解釈ができる • 両者はSDEを通して数理的に密接に関係する ‣ 無限層VAEや連続的なノイズレベルを想定すると、両者は等価 ‣ 同様にELBOとスコアマッチングロスも一定の条件下で等価
拡散モデルの数理 ④ サンプラーを理解する 谷口尚平 2024年6月19日
サンプラー • 拡散モデルのサンプラーにはたくさんの 種類がある • Stable Diffusionでは19個
サンプリング 拡散モデルにおけるサンプリング 逆拡散過程SDE/ODEのシミュレーション SDE: dxt = [f (xt, t) − g (t) sθ̂ (xt, t)] dt + g (t) dw 2 1 2 ODE: dxt = f (xt, t) − g (t) sθ̂ (xt, t) dt [ ] 2
Euler–Maruyama method オイラー・丸山法 dxt = [f (xt, t) − g (t) sθ̂ (xt, t)] dt + g (t) dw 2 離散化 xt+Δt ∼ ̂ (t) (t) x ; x + f x , t − g s x , t Δt, g ( ) ( ) t+Δt t t θ t ( [ ] 𝒩 • 最も基本的なSDEの1次ソルバー 2 2 Δt )
𝒪 Euler method オイラー法 1 2 dxt = f (xt, t) − g (t) sθ̂ (xt, t) dt [ ] 2 1 2 xt+Δt = xt + f (xt, t) − g (t) sθ̂ (xt, t) Δt [ ] 2 • 最も基本的なODEの1次ソルバー 2 • 離散化誤差は ((Δt) )
改善の方向性 • 2次以上のソルバーを使う • Heun • Linear multi-step • 解析的に解ける部分をうまく使う • DPM • DPM++
𝒪 Heun 2次のODEソルバー dxt = h (xt, t) dt x̃t+Δt = xt + h (xt, t) Δt 1 xt+Δt = xt + (h (xt, t) + h (x̃t, t)) Δt 2 • EDMなどで使われる 3 • Euler法の2倍の計算量がかかるが、離散化誤差が ((Δt) )
LMS Linear multi-step method dxt = h (xt, t) dt 3 1 xt+Δt = xt + h (xt, t) − h (xt−Δt, t − Δt) Δt (2 ) 2 • 現在の勾配だけでなく、過去の履歴も使って更新する • Heunと違って、Eulerと同じ計算量で済む(メモリは2倍かかる) • 3ステップ以上の履歴を使うこともできる
改善の方向性 • 2次以上のソルバーを使う • Heun • Linear multi-step • 解析的に解ける部分をうまく使う • DPM • DPM++
DPM 2 g (t) dxt = f (t) xt − ϵθ̂ (xt, t) dt 2σt [ ] • 第1項は xt に関して線形なので、解析的に解けることを用いると xt+Δt = αt+Δt αt xt − αt λt+Δt ∫λ t −λ ̂ e ϵθ (xtλ, tλ) dλ
DPM-Solver-1 1次近似する場合 xt+Δt = ≈ αt+Δt αt αt+Δt αt xt − αt λt+Δt ∫λ t −λ ̂ e ϵθ (xtλ, tλ) dλ xt − σt+Δt (e λt+Δt − 1) ϵθ̂ (xt, t) • 第1項の誤差がない分、正確にシミュレートできる • 1次近似の場合は、DDIMのサンプラーと等価 • Heun法と同じように2次近似にも拡張できる
DPM++ • 少ないステップ数でもFIDがあまり 劣化しない
x̂θ = (xt − σt ϵθ̂ ) / αt DPM++ • ODEをノイズ予測 ϵθ̂ の代わりにシグナル予測 x̂θ で書き換えると 2 g (t) dxt = f (t) xt − ϵθ̂ (xt, t) dt 2σt [ ] 2 2 αtg (t) g (t) ̂ = f (t) + x − x x , t dt ( ) t θ t 2 2 2σt ) 2σt ( • これも第1項が xt に関して線形なので、解析的に解ける
DPM-Solver++1 1次近似する場合 xt+Δt = ≈ σt+Δt σt σt+Δt σt xt − σt λt+Δt ∫λ t e xθ (xtλ, tλ) dλ xt − αt+Δt (e • DPMと同様に2次以上にも拡張できる λ ̂ −λt+Δt − 1) x̂θ (xt, t)
x̂θ = (xt − σt ϵθ̂ ) / αt SDE-DPM-Solver++1 (zt; 0, I) zt ∼ • ODEの場合と同様のことがSDEの場合にもできる λt+Δt σt+Δt λ −λ xt+Δt = e t t+Δtxt + 2αt ∫λ σt t e xθ (xtλ, tλ) dλ + 2(λ − λt+Δt) ̂ σt+Δt λ −λ 2(λt+Δt − λt) t t+Δt ̂ ≈ e xt + αt+Δt (1 − e x x , t + σ ( ) θ t t ) σt 𝒩 • これも2次以上の場合に拡張できる 2σt λt+Δt ∫λ e λ−λt+Δt dwλ t 1−e 2(λt+Δt − λt) zt
DPM++ • DPM以上に性能が良くなる
まとめ • 拡散モデルはODE/SDEをシミュレートすることでサンプリングする • オイラー法は最も素朴な離散化方法 • Heunなどの高次のソルバーを使ったり、一部の項を解析的に計算すること で、性能を向上できる