[DL Paper Reading Group] SUMO: Unbiased Estimation of Log Marginal Probability for Latent Variable Models

August 20, 2021

Slide summary

2021/08/20
Deep Learning JP:
http://deeplearning.jp/seminar-2/

Text of each page
1.

SUMO: Unbiased Estimation of Log Marginal Probability for Latent Variable Models
Shohei Taniguchi, Matsuo Lab

2.

Bibliographic information
SUMO: Unbiased Estimation of Log Marginal Probability for Latent Variable Models
Yucen Luo, Alex Beatson, Mohammad Norouzi, Jun Zhu, David Duvenaud, Ryan P. Adams, Ricky T. Q. Chen
https://arxiv.org/abs/2004.00353
Accepted at ICLR 2020
(Note) For readability, some of the notation in this presentation differs from that used in the paper.

3.

Overview
Unbiased estimation of the log marginal likelihood of latent variable models
• Latent variable models such as VAEs are trained by maximizing the log marginal likelihood $\log p_\theta(x) = \log \int p_\theta(x, z)\, dz$
• The log marginal likelihood is usually intractable, so VAEs maximize a lower bound on it instead
• This paper proposes a method that directly maximizes the log marginal likelihood rather than a lower bound
• The technique based on the Russian roulette estimator is the interesting part

4.

Outline
1. Unbiased estimators
2. Latent variable models (VAE, IWAE)
3. Stochastically Unbiased Marginalization Objective (SUMO)
• Russian roulette estimator
• Variance reduction
• Experiments

5.

Unbiased Estimators

6.

Unbiased Estimator
Quantity to estimate: $y$; estimator: $\hat{y}$. When $\mathbb{E}[\hat{y}] = y$ holds, $\hat{y}$ is said to be an unbiased estimator of $y$.

7.

Unbiased Estimator
Ex. 1: estimating the mean of a normal distribution
Suppose the data $x^{(1)}, \dots, x^{(n)}$ were drawn from a normal distribution $\mathcal{N}(x; \mu, \sigma^2)$. The sample mean $\bar{x} = \frac{1}{n}\sum_{i=1}^{n} x^{(i)}$ is an unbiased estimator of the population mean $\mu$:
$$\mathbb{E}[\bar{x}] = \mathbb{E}\left[\frac{1}{n}\sum_{i=1}^{n} x^{(i)}\right] = \mu$$

8.

Unbiased Estimator
Ex. 2: minibatch training
Empirical loss: $L_n = \frac{1}{n}\sum_{i=1}^{n} l(x^{(i)}, \theta)$
The minibatch statistic $\hat{L}_m = \frac{1}{m}\sum_{i=1}^{m} l(\tilde{x}^{(i)}, \theta)$ is an unbiased estimator of $L_n$, where each $\tilde{x}^{(i)}$ is chosen uniformly at random from $x^{(1)}, \dots, x^{(n)}$.
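
A quick numerical check of this property (a minimal NumPy sketch; the quadratic per-example loss and the synthetic data here are made up purely for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=1000)            # dataset x^(1), ..., x^(n)
loss = lambda xi: (xi - 0.5) ** 2    # an arbitrary per-example loss l(x, theta)

L_n = loss(x).mean()                 # empirical loss over all n examples

# Averaging many minibatch estimates recovers L_n, illustrating unbiasedness.
m = 10
estimates = [loss(rng.choice(x, size=m)).mean() for _ in range(100_000)]
print(L_n, np.mean(estimates))       # the two values should nearly agree
```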

9.

Unbiased Estimator
Ex. 3: reparameterization trick
We want to estimate $\nabla_\theta \mathbb{E}_{p_\theta(x)}[f(x)]$, where $p_\theta(x) = \mathcal{N}(x; \mu_\theta, \sigma_\theta)$.
With $\epsilon \sim \mathcal{N}(\epsilon; 0, 1)$, $\nabla_\theta f(\mu_\theta + \epsilon \cdot \sigma_\theta)$ is an unbiased estimator of $\nabla_\theta \mathbb{E}_{p_\theta(x)}[f(x)]$:
$$\mathbb{E}_{\epsilon \sim \mathcal{N}(\epsilon; 0, 1)}[\nabla_\theta f(\mu_\theta + \epsilon \cdot \sigma_\theta)] = \nabla_\theta \mathbb{E}_{p_\theta(x)}[f(x)]$$
• Used to estimate gradients for the VAE encoder
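
As a concrete check, here is a minimal NumPy sketch with the hand-picked test function $f(x) = x^2$, for which $\mathbb{E}_{p_\theta(x)}[f(x)] = \mu_\theta^2 + \sigma_\theta^2$, so the true gradient $(2\mu_\theta, 2\sigma_\theta)$ is known in closed form:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.8                  # theta = (mu, sigma)

eps = rng.normal(size=1_000_000)      # eps ~ N(0, 1)
x = mu + eps * sigma                  # reparameterized sample, x ~ N(mu, sigma^2)

grad_mu = (2 * x).mean()              # d/dmu   f(mu + eps*sigma) = 2(mu + eps*sigma)
grad_sigma = (2 * x * eps).mean()     # d/dsigma f(mu + eps*sigma) = 2(mu + eps*sigma)*eps
print(grad_mu, grad_sigma)            # should be close to (3.0, 1.6)
```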

10.

Unbiased Estimator
Ex. 4: likelihood ratio gradient estimator (REINFORCE)
We want to estimate $\nabla_\theta \mathbb{E}_{p_\theta(x)}[f(x)]$.
With $x \sim p_\theta(x)$, $f(x)\,\nabla_\theta \log p_\theta(x)$ is an unbiased estimator of $\nabla_\theta \mathbb{E}_{p_\theta(x)}[f(x)]$:
$$\mathbb{E}_{p_\theta(x)}[f(x)\,\nabla_\theta \log p_\theta(x)] = \nabla_\theta \mathbb{E}_{p_\theta(x)}[f(x)]$$
• Used in policy gradient methods for reinforcement learning
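
The same gradient can be estimated with the likelihood ratio trick, continuing the $f(x) = x^2$ example above (a sketch; the visibly larger spread of this estimator foreshadows the efficiency discussion two slides ahead):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.8
f = lambda x: x ** 2                 # true d/dmu E[f(x)] = 2*mu = 3.0

x = rng.normal(mu, sigma, size=1_000_000)
score = (x - mu) / sigma ** 2        # d/dmu log N(x; mu, sigma^2)
samples = f(x) * score               # per-sample REINFORCE gradient estimate
print(samples.mean(), samples.std()) # mean ~ 3.0, but with a much larger
                                     # standard deviation than in Ex. 3
```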

11.

Unbiased Estimator
Why unbiasedness matters
• In machine learning, we search for the parameters that minimize the empirical loss, e.g., with gradient methods.
• Even when the empirical loss itself cannot be computed, if an unbiased estimator of it can be, convergence to a local optimum can often still be guaranteed.
e.g., stochastic gradient descent (SGD): gradient descent using minibatch estimates is guaranteed to converge to a local optimum if the learning rate is scheduled appropriately.

12.

Efficient Estimator
Even if an estimator is unbiased, it cannot give stable estimates when its variance is large.
Ex. 1: with a small batch size, SGD has high variance and training becomes unstable.
Ex. 2: the reparameterization trick generally has lower variance than the likelihood ratio estimator.
The ideal estimator is one that is unbiased and has small variance; among unbiased estimators, the one with minimum variance is called the efficient estimator.

13.

Latent Variable Models

14.

Latent Variable Models
A class of models widely used in generative modeling:
$$p_\theta(x) = \int p_\theta(x, z)\, dz$$
The parameters $\theta$ are trained by maximizing the log marginal likelihood:
$$\log p_\theta(x) = \log \int p_\theta(x, z)\, dz$$

15.

Variational Inference
The log marginal likelihood cannot be computed directly, so a variational lower bound is used:
$$\log p_\theta(x) = \log \int p_\theta(x, z)\, dz \geq \mathbb{E}_{q(z)}\left[\log \frac{p_\theta(x, z)}{q(z)}\right] = \mathcal{L}(\theta, q)$$
This inequality holds with equality when $q(z) = p_\theta(z \mid x)$
➡ choose $q(z)$ from a family of distributions as close to $p_\theta(z \mid x)$ as possible
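
To make the bound concrete, here is a toy sketch (assumed model, chosen so everything is available in closed form: $p(z) = \mathcal{N}(0, 1)$, $p(x \mid z) = \mathcal{N}(z, 1)$, hence $p(x) = \mathcal{N}(0, 2)$ and $p(z \mid x) = \mathcal{N}(x/2, 1/2)$):

```python
import numpy as np

rng = np.random.default_rng(0)
x = 1.0                                   # a single observed data point

# log N(v; m, s2), reused for the prior, the likelihood, and q
log_norm = lambda v, m, s2: -0.5 * (np.log(2 * np.pi * s2) + (v - m) ** 2 / s2)

log_px = log_norm(x, 0.0, 2.0)            # exact log marginal likelihood

def elbo(m, s2, n=1_000_000):             # Monte Carlo estimate of L(theta, q)
    z = rng.normal(m, np.sqrt(s2), size=n)        # z ~ q(z) = N(m, s2)
    log_joint = log_norm(z, 0.0, 1.0) + log_norm(x, z, 1.0)
    return (log_joint - log_norm(z, m, s2)).mean()

print(log_px)                             # ~ -1.52
print(elbo(0.0, 1.0))                     # a loose q: strictly below log p(x)
print(elbo(x / 2, 0.5))                   # q = true posterior: matches log p(x)
```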

16.

Variational Autoencoder (VAE)
Give $q(z)$ its own parameters, $q_\phi(z \mid x)$, and learn it jointly.
The objective is to minimize the KL divergence to $p_\theta(z \mid x)$:
$$\mathrm{KL}(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)) = \log p_\theta(x) - \mathcal{L}(\theta, q_\phi)$$
The first term does not depend on $\phi$, so in the end both $\theta$ and $\phi$ can be trained by maximizing $\mathcal{L}$.
The gradient with respect to $\phi$ is estimated with the reparameterization trick described earlier.
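
The identity can be checked numerically in the same toy Gaussian model (a sketch; `kl_gauss` is the closed-form KL divergence between two univariate Gaussians):

```python
import numpy as np

rng = np.random.default_rng(0)
x = 1.0                                   # toy model from the previous sketch

log_norm = lambda v, m, s2: -0.5 * (np.log(2 * np.pi * s2) + (v - m) ** 2 / s2)

def kl_gauss(m1, s1sq, m2, s2sq):         # KL( N(m1, s1sq) || N(m2, s2sq) )
    return 0.5 * (np.log(s2sq / s1sq) + (s1sq + (m1 - m2) ** 2) / s2sq - 1)

def elbo(m, s2, n=1_000_000):
    z = rng.normal(m, np.sqrt(s2), size=n)
    return (log_norm(z, 0, 1) + log_norm(x, z, 1) - log_norm(z, m, s2)).mean()

m, s2 = 0.0, 1.0                          # a deliberately loose q
gap = log_norm(x, 0, 2) - elbo(m, s2)     # log p(x) - L(theta, q)
print(gap, kl_gauss(m, s2, x / 2, 0.5))   # both ~ 0.40: the identity holds
```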

17.

Issues with VAE
• Training of the generative parameters $\theta$ always depends on $q_\phi(z \mid x)$.
• If $q_\phi(z \mid x)$ is far from $p_\theta(z \mid x)$, the bound becomes loose and we end up maximizing something far removed from the log marginal likelihood we actually want to maximize.
(Figure credit: https://tips-memo.com/python-emalgorithm-gmm)

18.

Improving VAE
1. Increase the expressiveness of $q_\phi$
• A normal distribution is usually used for $q_\phi$; using a more flexible distribution makes the bound tighter.
• Normalizing flows, implicit variational inference
2. Change the objective
• Use an objective function whose bound is tighter.

19.

(Same content as slide 18.)

20.

Importance Weighted Autoencoder (IWAE)
$$\log p_\theta(x) = \log \mathbb{E}_{z^{(1)},\dots,z^{(k)} \sim q(z)}\left[\frac{1}{k}\sum_{i=1}^{k} \frac{p_\theta(x, z^{(i)})}{q(z^{(i)})}\right] \geq \mathbb{E}_{z^{(1)},\dots,z^{(k)} \sim q(z)}\left[\log \frac{1}{k}\sum_{i=1}^{k} \frac{p_\theta(x, z^{(i)})}{q(z^{(i)})}\right] = \mathcal{L}_k(\theta, q)$$
• When $k = 1$, this coincides with the VAE variational bound
• Equality holds as $k \to \infty$
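
A sketch of how the bound tightens with $k$, reusing the toy Gaussian model from the variational inference slide (with $q(z) = \mathcal{N}(0, 1)$, deliberately not the posterior):

```python
import numpy as np

rng = np.random.default_rng(0)
x = 1.0   # toy model: p(z) = N(0,1), p(x|z) = N(z,1), q(z) = N(0,1)

log_norm = lambda v, m, s2: -0.5 * (np.log(2 * np.pi * s2) + (v - m) ** 2 / s2)

def iwae_bound(k, n=200_000):
    z = rng.normal(size=(n, k))                     # z^(1..k) ~ q(z), n batches
    log_w = log_norm(z, 0, 1) + log_norm(x, z, 1) - log_norm(z, 0, 1)
    m = log_w.max(axis=1, keepdims=True)            # log-sum-exp for stability
    return (m[:, 0] + np.log(np.exp(log_w - m).mean(axis=1))).mean()

for k in (1, 5, 50):
    print(k, iwae_bound(k))   # L_k increases toward log p(x) ~ -1.52 as k grows
```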

21.

Importance Weighted Autoencoder (IWAE)
Performance improves as $k$ increases, outperforming the VAE.

22.

Stochastically Unbiased Marginalization Objective

23.

SUMO: Stochastically Unbiased Marginalization Objective
• Even with IWAE, the bound does not become tight unless $k$ is made sufficiently large.
• We would rather train on a quantity for which equality always holds (i.e., an unbiased estimator) than on a lower bound.
• Is there a way to obtain such an unbiased estimator?
➡ use the Russian roulette estimator

24.

Russian Roulette Estimator
Define
$$\Delta_k = \begin{cases} \mathcal{L}_1(\theta, q) & (k = 1) \\ \mathcal{L}_k(\theta, q) - \mathcal{L}_{k-1}(\theta, q) & (k \geq 2) \end{cases}$$
Then the series $\sum_{k=1}^{\infty} \Delta_k$ equals the log marginal likelihood:
$$\sum_{k=1}^{\infty} \Delta_k = \mathcal{L}_\infty(\theta, q) = \log p_\theta(x)$$

25.

Russian Roulette Estimator
Consider the following $\hat{y}$:
$$\hat{y} = \Delta_1 + \frac{\sum_{k=2}^{\infty} \Delta_k}{\mu} \cdot b, \quad b \sim \mathrm{Bernoulli}(\mu)$$
1. Flip a coin that comes up heads with probability $\mu$.
2. If heads, compute the terms from $k = 2$ onward, divide their sum by $\mu$, and add it to $\Delta_1$; if tails, compute only $\Delta_1$.

26.

Russian Roulette Estimator
It is easy to see that $\hat{y}$ is an unbiased estimator of $\sum_{k=1}^{\infty} \Delta_k$:
$$\mathbb{E}[\hat{y}] = \Delta_1 + \frac{\sum_{k=2}^{\infty} \Delta_k}{\mu} \cdot \mathbb{E}[b] = \sum_{k=1}^{\infty} \Delta_k$$

27.

Russian Roulette Estimator
Repeating the same construction from $k = 2$ onward, the following $\hat{y}$ is also an unbiased estimator of $\sum_{k=1}^{\infty} \Delta_k$:
$$\hat{y} = \sum_{k=1}^{K} \frac{\Delta_k}{\mu^{k-1}}, \quad K \sim \mathrm{Geometric}(K; 1 - \mu)$$
$K$ is the number of coin flips until the first tails (so it follows a geometric distribution).
Using this $\hat{y}$ yields an unbiased estimator of the log marginal likelihood.
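
The construction is easy to sanity-check on a toy series (a sketch with $\Delta_k = 2^{-k}$, whose infinite sum is exactly 1; the names are purely illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
mu = 0.6                             # P(heads) = mu, so P(K >= k) = mu^(k-1)

delta = lambda k: 0.5 ** k           # toy series with sum_{k>=1} Delta_k = 1

def roulette():
    K = rng.geometric(1 - mu)        # number of flips until the first tails
    return sum(delta(k) / mu ** (k - 1) for k in range(1, K + 1))

estimates = [roulette() for _ in range(200_000)]
print(np.mean(estimates))            # ~ 1.0, despite every draw being truncated
```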

28.

SUMO: Stochastically Unbiased Marginalization Objective
$$\log p_\theta(x) = \mathbb{E}_{K \sim p(K)}\left[\sum_{k=1}^{K} \frac{\Delta_k}{\mu^{k-1}}\right] = \underbrace{\mathcal{L}_1(\theta, q_\phi)}_{\text{same as the VAE bound}} + \underbrace{\mathbb{E}_{K \sim p(K)}\left[\sum_{k=2}^{K} \frac{\mathcal{L}_k(\theta, q_\phi) - \mathcal{L}_{k-1}(\theta, q_\phi)}{\mu^{k-1}}\right]}_{\text{correction term}}$$

29.

SUMO: Stochastically Unbiased Marginalization Objective
$$\mathrm{SUMO}(x) = \log w^{(1)} + \sum_{k=2}^{K} \frac{1}{\mu^{k-1}}\left(\log \frac{1}{k}\sum_{i=1}^{k} w^{(i)} - \log \frac{1}{k-1}\sum_{i=1}^{k-1} w^{(i)}\right)$$
$$w^{(i)} = \frac{p_\theta(x, z^{(i)})}{q_\phi(z^{(i)} \mid x)}, \quad K \sim p(K), \quad z^{(1)}, \dots, z^{(K)} \sim q_\phi(z \mid x)$$
SUMO is an unbiased estimator of the log marginal likelihood:
$$\mathbb{E}_{K \sim p(K),\, z^{(1)},\dots,z^{(K)} \sim q_\phi(z \mid x)}[\mathrm{SUMO}(x)] = \log p_\theta(x)$$
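
Putting the pieces together, here is a sketch of the SUMO estimator on the toy Gaussian model from earlier (assumptions: $q_\phi(z \mid x) = \mathcal{N}(0, 1)$ and a geometric $p(K)$; SUMO can be heavy-tailed, so the Monte Carlo average converges slowly):

```python
import numpy as np

rng = np.random.default_rng(0)
x, mu = 1.0, 0.6   # toy model: p(z) = N(0,1), p(x|z) = N(z,1), q(z|x) = N(0,1)

log_norm = lambda v, m, s2: -0.5 * (np.log(2 * np.pi * s2) + (v - m) ** 2 / s2)

def L(log_w, k):   # IWAE term log (1/k) sum_{i<=k} w^(i), computed stably
    t = log_w[:k].max()
    return t + np.log(np.exp(log_w[:k] - t).mean())

def sumo():
    K = rng.geometric(1 - mu)                       # K ~ Geometric(1 - mu)
    z = rng.normal(size=K)                          # z^(1), ..., z^(K) ~ q
    log_w = log_norm(z, 0, 1) + log_norm(x, z, 1) - log_norm(z, 0, 1)
    est = L(log_w, 1)                               # log w^(1)
    for k in range(2, K + 1):                       # roulette-weighted terms
        est += (L(log_w, k) - L(log_w, k - 1)) / mu ** (k - 1)
    return est

estimates = [sumo() for _ in range(100_000)]
print(np.mean(estimates), log_norm(x, 0, 2))        # both near log p(x) ~ -1.52
```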

30.

SUMO: variance reduction
The choice of $p(K)$ creates a trade-off between variance and compute:
• a distribution that favors small $K$ reduces compute but increases variance.
Variance can also be reduced by always computing the first $m$ terms $\Delta_k$:
$$\mathrm{SUMO}_m(x) = \log \frac{1}{m}\sum_{i=1}^{m} w^{(i)} + \sum_{k=m+1}^{m+K-1} \frac{1}{\mu^{k-m}}\left(\log \frac{1}{k}\sum_{i=1}^{k} w^{(i)} - \log \frac{1}{k-1}\sum_{i=1}^{k-1} w^{(i)}\right)$$
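
A sketch extending the SUMO code above so that the first $m$ terms are always computed (one assumption here: the roulette weight $\mu^{k-m}$ on the correction terms, which keeps the estimator unbiased by the same telescoping argument). Comparing standard deviations illustrates the variance reduction:

```python
import numpy as np

rng = np.random.default_rng(0)
x, mu = 1.0, 0.6   # same toy model and notation as the previous sketch

log_norm = lambda v, m, s2: -0.5 * (np.log(2 * np.pi * s2) + (v - m) ** 2 / s2)

def L(log_w, k):
    t = log_w[:k].max()
    return t + np.log(np.exp(log_w[:k] - t).mean())

def sumo_m(m):
    K = rng.geometric(1 - mu)
    z = rng.normal(size=m + K - 1)                 # always at least m samples
    log_w = log_norm(z, 0, 1) + log_norm(x, z, 1) - log_norm(z, 0, 1)
    est = L(log_w, m)                              # first m terms, no roulette
    for k in range(m + 1, m + K):                  # roulette starts after m
        est += (L(log_w, k) - L(log_w, k - 1)) / mu ** (k - m)
    return est

a = [sumo_m(1) for _ in range(50_000)]             # m = 1 recovers plain SUMO
b = [sumo_m(5) for _ in range(50_000)]
print(np.mean(a), np.std(a))                       # unbiased, higher variance
print(np.mean(b), np.std(b))                       # unbiased, lower variance
```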

31.

SUMO: training the encoder
Viewed from the encoder side, SUMO is constant (in expectation) with respect to the parameters $\phi$, so training the encoder on the same loss, as in a VAE, is meaningless.
The paper proposes instead training $\phi$ to minimize the variance of the estimator:
$$\nabla_\phi \mathbb{V}(\mathrm{SUMO}(x)) = \mathbb{E}\left[\nabla_\phi (\mathrm{SUMO}(x))^2\right]$$

32.

SUMO: experiments (generative modeling)
SUMO consistently outperforms IWAE and related baselines.

33.

SUMO: entropy maximization
We want to approximate a distribution $p^*(x)$ whose density function is known but which is hard to sample from.
When fitting a latent variable model to it, minimizing the reverse KL is commonly used:
$$\min_\theta \mathrm{KL}(p_\theta(x) \,\|\, p^*(x)) = \min_\theta \mathbb{E}_{x \sim p_\theta(x)}[\log p_\theta(x) - \log p^*(x)]$$
The entropy term (the first term) is hard to compute.

34.

Entropy maximization
If IWAE is used to estimate $\log p_\theta(x)$, we end up minimizing a lower bound of the objective:
$$\mathbb{E}_{p_\theta(x)}[\log p_\theta(x)] \geq \mathbb{E}_{p_\theta(x),\, z^{(1)},\dots,z^{(k)} \sim q_\phi(z \mid x)}\left[\log \frac{1}{k}\sum_{i=1}^{k} \frac{p_\theta(x, z^{(i)})}{q_\phi(z^{(i)} \mid x)}\right] = \mathbb{E}_{p_\theta(x)}\left[\log p_\theta(x) - \mathrm{KL}(\tilde{q}_{\theta,\phi}(z \mid x) \,\|\, p_\theta(z \mid x))\right]$$
so the minimization ends up trying to maximize the KL term.
Using SUMO avoids this problem:
$$\mathbb{E}_{p_\theta(x)}[\log p_\theta(x)] = \mathbb{E}[\mathrm{SUMO}(x)]$$

35.

SUMO: experiments (entropy maximization)
IWAE collapses partway through training unless the number of samples is increased considerably; SUMO trains stably.
The density estimated with SUMO is also more accurate.

36.

SUMO: application to REINFORCE
REINFORCE requires that $\log p_\theta(x)$ be computable:
$$\nabla_\theta \mathbb{E}_{p_\theta(x)}[f(x)] = \mathbb{E}_{p_\theta(x)}[f(x)\,\nabla_\theta \log p_\theta(x)]$$
If a latent variable model is used for $p_\theta(x)$, this is no longer easy to compute.
e.g., using a latent variable model as a reinforcement learning policy:
$$\pi_\theta(a \mid s) = \int p_\theta(z)\, p_\theta(a \mid s, z)\, dz$$

37.

SUMO: application to REINFORCE
With SUMO, this too can be estimated without bias:
$$\nabla_\theta \mathbb{E}_{p_\theta(x)}[f(x)] = \mathbb{E}_{p_\theta(x)}[f(x)\,\nabla_\theta \log p_\theta(x)] = \mathbb{E}[f(x)\,\nabla_\theta \mathrm{SUMO}(x)]$$

38.

Experiments (reinforcement learning)
Consider maximizing $\mathbb{E}_{x \sim p_\theta(x)}[R(x)]$ in a simple reinforcement learning problem with no temporal structure.
Training uses REINFORCE, i.e., the policy gradient method:
$$\nabla_\theta \mathbb{E}_{x \sim p_\theta(x)}[R(x)] = \mathbb{E}_{p_\theta(x)}[R(x)\,\nabla_\theta \log p_\theta(x)]$$
When a latent variable model is used for the policy $p_\theta(x)$, SUMO can be used:
$$\nabla_\theta \mathbb{E}[R(x)] = \mathbb{E}_{x \sim p_\theta(x)}[R(x)\,\nabla_\theta \mathrm{SUMO}(x)]$$

39.

SUMO: experiments (reinforcement learning)
As the policy $p_\theta(x)$, compare
1. a latent variable model
2. an autoregressive model
3. an independent model
and, for 1., also compare training with IWAE versus SUMO.

40.

SUMO: experiments (reinforcement learning)
1. Latent variable model: $p_{\mathrm{LVM}}(x) := \int \prod_i p_\theta(x_i \mid z)\, p(z)\, dz$ (high expressiveness, fast sampling)
2. Autoregressive model: $p_{\mathrm{Autoreg}}(x) := \prod_i p(x_i \mid x_{<i})$ (high expressiveness, slow sampling)
3. Independent model: $p_{\mathrm{Indep}}(x) := \prod_i p(x_i)$ (low expressiveness, fast sampling)

41.

SUMO: experiments (reinforcement learning)
SUMO and the autoregressive model perform best; the autoregressive model is 19.2× slower than SUMO.

42.

Summary
• Latent variable models have traditionally been trained by maximizing a lower bound on the log marginal likelihood.
• This work proposes SUMO, which uses the Russian roulette estimator to directly maximize an unbiased estimate of the log marginal likelihood.
• SUMO can also be applied to reverse KL minimization, reinforcement learning, and more.
Impressions
• A broadly applicable idea that seems useful in many settings (e.g., mutual information estimation).