【DL輪読会】Conditional Flow Matching

12.4K Views

April 18, 24

スライド概要

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

関連スライド

各ページのテキスト
1.

DEEP LEARNING JP [DL Papers] Conditional Flow Matching Presenter: Masahiro Suzuki, Matsuo Iwasawa Lab 2024/04/18 http://deeplearning.jp/ 1

2.

本発表情報 ¤ Conditional Flow Matchingについてのまとめ ¤ Stable Diffusion 3のテクニカルレポート[Esser+ 24]でもconditional flow matchingの⽂脈で説明されている. ¤ 各研究論⽂のほかに,主に以下の論⽂や記事を参照. ¤ Improving and Generalizing Flow-Based Generative Models with Minibatch Optimal Transport ¤ AN INTRODUCTION TO FLOW MATCHING ¤ https://mlg.eng.cam.ac.uk/blog/2024/01/20/flow-matching.html ¤ 選定理由 ¤ この辺りの研究がいろいろあってよくわからなくなってきたので,まとめようと思った. ¤ 謝罪 ¤ 以前の発表とかぶっていたことに気づきました・・・ ¤ https://www.slideshare.net/DeepLearningJP2016/flow-matching-for-generative-modeling ¤ https://www.docswell.com/s/DeepLearning2023/K4Q9ED-2023-09-15-114701 2

3.

背景︓確率微分⽅程式 ¤ 以下の確率微分⽅程式(stochastic differential equation︔SDE)を考える. ¤ 伊藤のSDEと呼ばれる. ¤ 𝑓はドリフト係数, 𝑔は拡散係数, d𝑤は標準ウィーナー過程(𝑑𝑤~𝒩(0, 𝑑𝑡)). 𝑑𝑥 = 𝑓(𝑥, 𝑡)𝑑𝑡 + 𝑔(𝑡)𝑑𝑤 ¤ SDEを解くことで拡散過程 𝑥! !∈[$,&] が得られる. 3

4.

背景︓逆時間確率微分⽅程式 ¤ 前ページのSDEの逆過程(逆時間SDE)は次のようになる[Anderson+ 82]. 𝑑𝑥 = 𝑓(𝑥, 𝑡) − 𝑔(𝑡).𝛻/ log 𝑝(𝑥0 ) 𝑑𝑡 + 𝑔(𝑡)𝑑𝑤 2 逆時間SDEの導出は次のサイトを参照︓ https://ludwigwinkler.github.io/blog/ReverseTimeAnderson/ ¤ 前ページのSDEと同じ拡散過程 𝑥! !∈[$,&] が得られる. ¤ スコア𝛻( log 𝑝! (𝑥)を推定できれば,上記の逆時間SDEで事前分布からデータ分布のサンプルが得ら れる. 4

5.
[beta]
背景︓スコアマッチング

¤ 以下の⽬的関数の最適化によってスコアを推定するスコアベースモデル 𝑠! (𝑥, 𝑡) を学習する[Song+ 20] .
¤ ノイズ除去スコアマッチング[Vincent 11]の⽬的関数.

1 &
+ 𝔼
𝜆(𝑡) 𝛻($ log 𝑝(𝑥! |𝑥$ ) − 𝑠- 𝑥! , 𝑡
2 $ )(($))(($|(%)

.
.

𝑑𝑡

¤ 𝜆: 0,1 → ℝ!" は重みづけ関数.
¤ 遷移分布𝑝(𝑥# |𝑥" )は(ドリフト係数𝑓が𝑥について線型ならば)ガウス分布となるので,陽に計算できる.

¤ SDEの形によって,⽬的関数(遷移分布)が変わる.
¤ Variance Exploding (VE) SDE︓連続時間のスコアベース(score matching with Langevin dynamics︔SMLD
[Song+ 19] )

𝑑𝑥 = 𝜎 t 𝑑𝑤 ⇒ 𝑝 𝑥! 𝑥$ = 𝒩 𝑥( ; 𝑥$ , 𝜎 ) (𝑡) − 𝜎 ) (0) 𝐼
¤ Variance Preserving (VP) SDE︓連続時間の拡散モデル(denoising diffusion probabilistic modeling︔DDPM
[Ho+ 20] )
& "
"
1
* ∫! ,(.)0.
*∫! ,(.)0.
)
𝑑𝑥 = − 𝛽 𝑡 𝑥𝑑𝑡 + 𝛽 𝑡 𝑑𝑤 ⇒ 𝑝 𝑥! 𝑥$ = 𝒩 𝑥! ; 𝑥$ 𝑒
, 𝐼 − 𝐼𝑒
2

5

6.

背景︓スコアマッチング ¤ スコアマッチングの⽬的関数(再掲) 1 & + 𝔼 𝜆(𝑡) 𝛻($ log 𝑝(𝑥! |𝑥$ ) − 𝑠- 𝑥! , 𝑡 2 $ )(($))(($|(%) . . 𝑑𝑡 ¤ (𝜆(𝑡)を適切に設定することで)対数尤度の下界を最⼤化することに対応する[Song+ 21, Kingma+ 21] ¤ スコアマッチングの特徴︓ ¤ ⽬的関数の評価にSDEのシミュレーションが不要(simulation-free)なので,学習が効率的. ¤ ただし推論時にはSDEシミュレーションが必要であるため,推論効率を向上させる⼯夫が必要(蒸留な ど[Salimans+ 22]). 6

7.

背景︓正規化フロー ¤ 正規化フロー[Rezende+ 15] ¤ 観測変数𝑥! が,シンプルな確率密度関数𝑝" (𝑥" )と可逆で微分可能な変換(フロー)𝜙 # によって次のよ うに書けるとする. 𝑥/ = 𝜙/ 𝑥$ , 𝑥$ ~𝑝$ (𝑥$ ) ¤ このとき確率密度関数𝑝! 𝑥! は次のように書ける(変数変換の公式またはpush-forward equation) 𝑑𝑥" 𝑑𝜙!&% (𝑥! ) &% 𝑝! 𝑥! = 𝑝" 𝑥" det = 𝑝! 𝜙! (𝑥! ) ≡ [𝜙! ] ∗ 𝑝" 𝑑𝑥! 𝑑𝑥! ¤ データ分布について,これを最⼤化するような𝜙を求めれば良い. ¤ 最尤推定のためには,変換𝜙は可逆かつ逆関数とヤコビアンが簡単に計算できないといけない. ¤ そこで単純な𝜙を合成することで複雑な分布を表現する(正規化フロー) 𝑝! 𝑥$ = [𝜙! ∘ ⋯ ∘ 𝜙% ] ∗ 𝑝" 𝑥! ・・・ 𝑥" ・・・ 𝑥# 7

8.

背景︓Residual flow ¤ ヤコビアンを効率的に計算するために,様々な⼯夫がある. ¤ 詳細は以前の輪読資料を参照. ¤ https://www.slideshare.net/DeepLearningJP2016/dlflowbased-deep-generative-models ¤ Residual Flow[Chen+ 20] ¤ 以下のフローにおいて(𝑢' が1/𝛿リプシッツのとき)対数尤度の不偏推定量が得られる. ¤ このときヤコビアンはフルランクとなる. 𝑥J = 𝜙J 𝑥 = 𝑥 + 𝛿𝑢J 𝑥 ¤ 式変形すると, 𝜙0 𝑥 − 𝑥 = 𝑢0 𝑥 𝛿 8

9.

背景︓連続正規化フロー ¤ 𝛿 = 1/𝐾として𝐾 → ∞とすると,前ページの式は次のような常微分⽅程式(ordinary differential equation︔ODE)の形になる(continuous normalizing flow︔CNF [Chen+ 18, Lipman+ 23]) 𝑑𝜙# 𝑥" = 𝑢# 𝜙# 𝑥" 𝑑𝑡 ¤ ただし𝑡 ∈ [0,1] ¤ 𝑢# はベクトル場,𝜙# 𝑥" = 𝑥# はフロー ¤ シンプルに書くと &'$ &# = 𝑢# 𝑥# ¤ 連続の⽅程式(流体における保存則)によって,時刻𝑡の𝑥# における対数確率密度log 𝑝# 𝑥# (𝑝# は確 率パス)の変化は次のように計算できる. 𝑑 log 𝑝# 𝑥# = −div 𝑢# (𝑥# ) 𝑑𝑡 ¤ 任意のODEソルバーで𝑥# およびlog 𝑝# 𝑥# を計算できる. 9

10.

背景︓連続正規化フロー ¤ 連続正規化フローのイメージ(𝑝$ から𝑝& ) ¤ 左︓各時刻の確率パス ¤ 中︓各時刻のベクトル場 ¤ 右︓サンプルの軌跡 https://github.com/atong01/conditional-flow-matching 10

11.

背景︓SDEとprobability flow ODE ¤ 全てのSDEには周辺確率パス{𝑝! (𝑥)}が⼀致する次のprobability flow ODE [Song+ 21]がある. 1 𝑑𝑥 = 𝑓(𝑥, 𝑡) − 𝑔(𝑡)( 𝛻) log 𝑝$ (𝑥) 𝑑𝑡 2 ( ¤ 連続正規化フローのベクトル場が𝑓(𝑥, 𝑡) − ) 𝑔(𝑡))𝛻' log 𝑝# (𝑥)となることに対応. 11

12.

連続正規化フローの学習の課題 ¤ 連続正規化フローはODEの形になっているので,ランダム性を含むSDEより効率的に𝑥! のシ ミュレーション(推論)を計算できる. ¤ ⼀⽅𝑢! について最尤学習するためには,毎回ODEソルバーで対数尤度を計算する必要がある. ¤ 以下の式の積分の部分. % log 𝑝% (𝑥) = log 𝑝" 𝑥" − O ∇ ⋅ 𝑢$ 𝑥$ d𝑡 " ¤ ここまでのまとめ ¤ スコアベースモデルは,学習が効率的(simulation-free)だが,推論に計算コストがかかる. ¤ 連続正規化フローは,推論は効率的だが,学習コストがかかる. => 連続正規化フローをスコアベースモデルと同様のsimulation-freeな⽅法で学習できないか︖ 12

13.

フローマッチング ¤ ベクトル場𝑢! (𝑥)をニューラルネットワーク𝑣- (𝑥, 𝑡)で近似することを考え,以下の⽬的関数を 導⼊する. 𝔼0∼𝒰(T,U),/∼V$ (/) 𝑣W (𝑡, 𝑥) − 𝑢0 (𝑥) . ¤ スコアマッチングと同様に,回帰にすることでsimulation-freeな⽬的関数となる. ¤ ⽬的関数が0になれば, 𝑝% は真の分布𝑞を近似する. ¤ この⽬的関数による学習をフローマッチングという[Lipman+ 23]. ※ 正確にはフローではなくベクトル場をマッチングしている. ¤ フローマッチングの課題 ¤ 𝑢$ (𝑥)を求める必要がある(そもそも𝑢$ (𝑥) がわかっていれば近似する必要はない) ¤ 𝑝% ≈ 𝑞を実現する有効な確率パス(およびベクトル場)はいくらでも存在し,どの𝑝* や𝑢$ が適切なの かは明らかではない. 13

14.

ベクトル場の任意性 ¤ 𝑝$ から同じ𝑝& に到達する場合でも,様々なベクトル場が考えられる. https://mlg.eng.cam.ac.uk/blog/2024/01/20/flow-matching.html 14

15.
[beta]
条件付き確率パス
¤ データサンプル𝑥& ~𝑞(𝑥& ) で条件付けた確率パス𝑝! 𝑥 𝑥& を考える.
2
¤ ただし𝑝" 𝑥 𝑥( = 𝑝"(𝑥)および𝑝1 𝑥 ∣ 𝑥1 = 𝒩 𝑥; 𝑥1 , 𝜎min
𝐼

⟶ 𝛿𝑥1 (𝑥)

𝜎min →0

¤ 𝑡 = 0で元の分布, 𝑡 = 1でデータサンプル(にノイズが少し⼊った分布)になる.

https://mlg.eng.cam.ac.uk/blog/2024/01/20/flow-matching.html

¤ 上記の条件付き確率パスを⽤いて,周辺確率パスは次のようになる.

𝑝0 𝑥 = ∫ 𝑞 𝑥U 𝑝0 𝑥 𝑥U 𝑑𝑥U

¤ 𝑡 = 1で𝑝% はデータ分布𝑞を近似する.

15

16.

条件付きフローマッチング ¤ 同様に,ベクトル場を条件付きベクトル場𝑢! 𝑥 ∣ 𝑥& によって表す. 𝑢$ (𝑥) = O 𝑢$ 𝑥 ∣ 𝑥% 𝑝$ 𝑥 ∣ 𝑥% 𝑞 𝑥% 𝑑𝑥% 𝑝$ (𝑥) ¤ この周辺ベクトル場は,前ページの周辺確率パスと対応していることが証明されている. ¤ したがって,次の⽬的関数を導⼊する(条件付きフローマッチング). 𝔼0,\ /% ,V$ /∣/% 𝑣W (𝑡, 𝑥) − 𝑢0 𝑥 ∣ 𝑥U . ¤ この勾配は,フローマッチングの⽬的関数の勾配と等しくなることが証明されている. ¤ 𝑢$ 𝑥 ∣ 𝑥% や𝑝$ 𝑥 ∣ 𝑥% は容易に計算できるので(次ページで説明),この⽬的関数の最適化も効率的に 実⾏できる. 16

17.

条件付きフローマッチング ¤ 条件付き確率パスを次のようにガウス分布で定義する. 𝑝0 𝑥 ∣ 𝑥U = 𝒩 𝑥; 𝜇0 𝑥U , 𝜎0 𝑥U .I ¤ このとき,フローは𝜙$ 𝑥 ∣ 𝑥% = 𝜇$ 𝑥% + 𝜎$ 𝑥% 𝑥 ¤ 対応するベクトル場の1つは次のように計算できる. 𝜎!1 𝑥& 𝑢! 𝑥 ∣ 𝑥& = 𝑥 − 𝜇! 𝑥& 𝜎! 𝑥& + 𝜇!1 𝑥& ¤ 𝜇! や𝜎! の選択について︓ ¤ スコアマッチングにおける遷移分布とみなせば,VPやVEと同じ形で置ける. ¤ ODEなので遷移に確率的要素がない(probability flow ODE [Song+ 21] ) ¤ フローマッチングでは𝑝" がノイズ, 𝑝% がデータと逆になっていることに注意. ¤ 他にも,よりシンプルな選択が考えられる. 17

18.

最適輸送に対応した条件付けベクトル場 ¤ 前述のように,条件付き確率パス𝑝! 𝑥 ∣ 𝑥& = 𝒩 𝑥; 𝜇! 𝑥& , 𝜎! 𝑥& . I は以下を満たせばいい. ¤ 𝑡 = 0で𝑝" 𝑥 𝑥% = 𝑝" 𝑥 ( ¤ 𝑡 = 1で𝑝% 𝑥 ∣ 𝑥% = 𝒩 𝑥; 𝑥% , 𝜎456 𝐼 ⟶ 7'() →" 𝛿)* (𝑥) ¤ 𝑝$ 𝑥 = 𝒩 𝑥; 0, 𝐼 とすると,最も単純なのは次のような線型補完の形. 𝜇$ (𝑥) = 𝑡𝑥% , 𝜎$ (𝑥) = 1 − 1 − 𝜎456 𝑡 ¤ フローは𝜙$ 𝑥 ∣ 𝑥% = 𝑡𝑥% + 1 − 1 − 𝜎456 𝑡 𝑥 ¤ 対応するベクトル場の⼀つは𝑢$ 𝑥 ∣ 𝑥% = )*& %&7'() ) %& %&7'() $ ¤ この条件付けフローは𝑝$ (𝑥|𝑥& )から𝑝& (𝑥|𝑥& )への最適輸送の移送補間に対応している. ¤ ただし「条件付き」確率パスの最適輸送であることに注意 18

19.

最適輸送によるフローマッチングの有効性 ¤ VEやVPを選択した場合と⽐べると,最適輸送の場合は⽅向を変えずにターゲットに向かっている. ¤ VEやVPの場合(図ではdiffusionと表記)は,遠回りした軌道を描いている. ¤ 2Dのチェッカーボードへの遷移 ¤ 最適輸送によるフローマッチングが⼀番ターゲットに到達するのが早い. ¤ 少ないシミュレーション数(NFE,オイラー法を利⽤)で⽬標のサンプルが実現できている. 19

20.

条件付きフローマッチングの課題 ¤ 条件付きフローマッチングでは,任意のサンプルから特定のサンプルへの条件付きベクトル場を考え ていた(one-sided conditioning) ¤ しかし,その場合真の周辺ベクトル場と異なる軌跡になる可能性がある. ¤ 周辺パスはクロスしていない(ODEの性質)にも関わらず,条件付きパスではクロスしてしまっている.その ためベクトル場の分散が⼤きくなり収束が遅くなる. ¤ 実際には右の分布に遷移して欲しいにも関わらず,条件付きパスではクロスして別の分布に遷移している.周 辺パスも変に曲がってしまっていて(=最適輸送ではない),推論の際にサンプリングが遅くなる. 20

21.

Two-sided conditioning ¤ 前ページの問題を解決するために,両端(𝑡 = 0,1)のサンプルで条件づける(two-sided conditioning) 𝑝0 𝑥0 = ∫ 𝑝0 𝑥0 ∣ 𝑥U, 𝑥T 𝑞 𝑥U, 𝑥T d𝑥Ud𝑥T ¤ ただし𝑝" 𝑥 ∣ 𝑥% , 𝑥" = 𝛿)! および𝑝% 𝑥 ∣ 𝑥% , 𝑥" = 𝛿)* とする(両端でサンプルになる) ¤ 𝑞 𝑥& , 𝑥$ = 𝑞 𝑥& 𝑞(𝑥$ )として, 𝜇! 𝑥$ , 𝑥& , 𝜎! (𝑥$ , 𝑥& )を設定する. ¤ Rectified Flow[Liu+ 22]︓ ¤ 𝜇$ 𝑥" , 𝑥% = 𝑡𝑥% + 1 − 𝑡 𝑥" ,𝜎$ 𝑥" , 𝑥% = 0 ¤ Variance preserving stochastic interpolant [Albergo+ 23](Cosine [Nichol+ 21])︓ ¤ 𝜇$ 𝑥" , 𝑥% = cos % ( 𝜋𝑡 𝑥" + sin % ( 𝜋𝑡 𝑥% ,𝜎$ 𝑥" , 𝑥% = 0 ¤ Independent CFM(I-CFM)[Tong+ 24]︓ ¤ 𝜇$ 𝑥" , 𝑥% = 𝑡𝑥% + 1 − 𝑡 𝑥" ,𝜎$ 𝑥" , 𝑥% = 𝜎 21

22.

Two-sided conditioning ¤ 2つの任意の異なる分布間を補間することができる. https://mlg.eng.cam.ac.uk/blog/2024/01/20/flow-matching.html ¤ 課題︓ ¤ 𝑞 𝑥% , 𝑥" = 𝑞 𝑥% 𝑞(𝑥" )のように異なる分布を独⽴としているため,前述のクロスの問題などは残る (周辺確率フローでの最適輸送にはならない) 22

23.

最適輸送カップリング ¤ 𝑞 𝑥& , 𝑥$ の独⽴の仮定を除き,次のように設定する(最適輸送カップリング) 𝑞 𝑥U, 𝑥T = 𝜋 𝑥U, 𝑥T = arg infc∈e ∫ 𝑥U − 𝑥T ¤ 2-Wasserstein距離を最⼩化する輸送計画𝜋 𝑥% , 𝑥" とする. . . d𝜋 𝑥U, 𝑥T ¤ 最適輸送カップリングによってクロスする問題が解消される. https://mlg.eng.cam.ac.uk/blog/2024/01/20/flow-matching.html 23

24.

ミニバッチOT-CFM ¤ ⼤規模なデータセットに対して最適輸送カップリングを計算するのは困難なので,ミニバッチ に対して最適輸送を計算することで近似をする. ¤ ミニバッチが⼩さければ線型計画として任意のソルバで解ける. ¤ 最適輸送カップリングを利⽤する以外は,two-sided conditioning(I-CFM)と同じ. 24

25.

ミニバッチOT-CFM ¤ 右︓クロスの問題が解決し,周辺ベクトル場が最適輸送になっている. https://mlg.eng.cam.ac.uk/blog/2024/01/20/flow-matching.html 25

26.

Schrödinger bridge CFM ¤ Schrödinger bridge problem[Chen+21] ¤ 両端の分布が与えられたときに,ある確率過程がこれらに従うような条件付き確率過程を求める問題 ¤ 𝑝89: を標準ウィーナー過程として,以下を求める. 𝜋 ∗: = arg min 3 (% 45 (% ,3 (* 45 (* KL 𝜋 ∥ 𝑝ref ¤ エントロピー正則化付き最適輸送問題の解(ただし𝜆 = 2𝜎 . )で同時分布を定義する. 𝑞 𝑥& , 𝑥$ = 𝜋 𝑥& , 𝑥$ = arg inf3∈6 ∫ 𝑥& − 𝑥$ .. d𝜋 𝑥& , 𝑥$ − 𝜆KL 𝜋 ∥ 𝑞$ ⊗ 𝑞& ¤ 条件付き確率パスを次のように設定すると,周辺確率パスとSchrödinger bridge problemの 解𝜋 ∗が⼀致する[Tong+ 24]. 𝑝! (𝑥 ∣ 𝑥& , 𝑥$ ) = 𝒩 𝑥 ∣ 𝑡𝑥& + (1 − 𝑡)𝑥$ , 𝑡(1 − 𝑡)𝜎 . ¤ スコアマッチングと合わせて確率過程で学習する⽅法も提案している[Tong+ 24] 26

27.

それぞれの条件付き確率パスの違い [Tong+ 24] 27

28.

⽐較結果 ¤ CIFAR10での学習結果 ¤ 最適輸送のアプローチが最も⾼い結果 ¤ Adaptiveは適応的なソルバ(dopri5)を利⽤した結果 28

29.

⽐較結果 https://github.com/atong01/conditional-flow-matching Action-Matching[Neklyudov+ 23]︓ • two-sided conditioningの1つだが,ベクトル場の勾配(アクション)をマッチングしている. 29

30.

Stable Diffusion 3 ¤ Conditional Flow Matchingで学習[Esser+ 24] ¤ 条件付きフローはRectified Flowを利⽤. ¤ 𝑡 ∈ [0,1] のサンプリングを⼀様分布ではなく,特定の分布に従ってサンプリングする. ようにする. ¤ [0,1]の中間地点の学習が難しいので,より多くサンプリングされるようにする. ¤ Logit-Normal分布を利⽤︓ % % ; (< $(%&$) exp − (?@A5*($)&B)+ (; + ¤ ほかにもMode Sampling with heavy Tails,CosMapで検証. ¤ アーキテクチャ的には,バックボーンをUNetからDiffusion Transformer(DiT) に変更. ¤ 事前学習したautoencoderの潜在変数上で学習する点はLDMと同じにする. ¤ その他,Multi-modal DiT(それぞれのモダリティで学習する重みを分ける)を使う等 の⼯夫をしている. 30

31.

Stable Diffusion 3 ¤ 学習︓ImageNet,CC12M,テスト︓COCO-2014 ¤ Rectified Flow + Logit-Normal Sampling が平均して良い結果(左).ステップ数が少ない回数で良 い結果(サンプル効率がいい)となっている(右). 31

32.

Stable Diffusion 3 32

33.

まとめ ¤ Conditional Flow Matchingを中⼼に関連研究をまとめた. ¤ SDE->スコアマッチング,ODE->フローマッチング ¤ One-sided conditioning ¤ Two-sided conditioning(Rectified Flowなど) ¤ 最適輸送カップリング ¤ Stable Diffusion 3 33