2021/08/20
Deep Learning JP: http://deeplearning.jp/seminar-2/
DL Reading Group (DL輪読会) presentation material

SUMO: Unbiased Estimation of Log Marginal Probability for Latent Variable Models
Shohei Taniguchi, Matsuo Lab

Paper Information
SUMO: Unbiased Estimation of Log Marginal Probability for Latent Variable Models
Yucen Luo, Alex Beatson, Mohammad Norouzi, Jun Zhu, David Duvenaud, Ryan P. Adams, Ricky T. Q. Chen
https://arxiv.org/abs/2004.00353
Accepted at ICLR 2020
Note: for clarity, some notation in this presentation differs from the paper.

Overview
Unbiased estimation of the log marginal likelihood of latent variable models
• Latent variable models such as VAEs are trained by maximizing the log marginal likelihood
  log p_θ(x) = log ∫ p_θ(x, z) dz
• The log marginal likelihood is usually intractable, so VAEs maximize a lower bound on it instead
• This paper proposes a method that maximizes the log marginal likelihood itself, not a lower bound
• The Russian roulette estimator used as the key technique is particularly interesting

Outline
1. Unbiased estimators
2. Latent variable models (VAE, IWAE)
3. Stochastically Unbiased Marginalization Objective (SUMO)
   • Russian roulette estimator
   • Variance reduction
   • Experiments

Unbiased Estimators

Unbiased Estimator
Quantity to estimate: y; estimator: ŷ
When E[ŷ] = y holds, ŷ is called an unbiased estimator of y.

Unbiased Estimator — Ex. 1: Estimating the mean of a Gaussian
Suppose the data x^(1), …, x^(n) are drawn from a Gaussian N(x; μ, σ²).
The sample mean x̄ = (1/n) ∑_{i=1}^n x^(i) is an unbiased estimator of the mean μ:
E[x̄] = E[(1/n) ∑_{i=1}^n x^(i)] = μ

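This property is easy to check numerically. A minimal sketch (the parameters μ = 2, σ = 1 and the sample sizes are arbitrary choices for illustration):

```python
import numpy as np

def sample_mean(rng, mu, sigma, n):
    """Draw n samples from N(mu, sigma^2) and return their sample mean."""
    return rng.normal(mu, sigma, size=n).mean()

# Average many independent sample means: the result concentrates
# around mu, illustrating E[x_bar] = mu.
rng = np.random.default_rng(0)
means = [sample_mean(rng, mu=2.0, sigma=1.0, n=10) for _ in range(20000)]
estimate = float(np.mean(means))
```

Each individual sample mean is noisy, but its expectation is exactly μ, which is what unbiasedness asserts.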
Unbiased Estimator — Ex. 2: Minibatch training
Empirical loss: L_n = (1/n) ∑_{i=1}^n l(x^(i), θ)
The minibatch statistic L̂_m = (1/m) ∑_{i=1}^m l(x̃^(i), θ) is an unbiased estimator of L_n,
where each x̃^(i) is chosen uniformly at random from x^(1), …, x^(n).

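The same check works for the minibatch statistic. A minimal sketch, assuming a hypothetical dataset of per-example losses (sampling here is uniform with replacement):

```python
import numpy as np

def minibatch_loss(rng, losses, m):
    """Mean loss over a size-m minibatch drawn uniformly (with replacement)."""
    idx = rng.integers(0, len(losses), size=m)
    return losses[idx].mean()

# Per-example losses for a hypothetical dataset of n = 100 points.
rng = np.random.default_rng(0)
losses = rng.uniform(0.0, 2.0, size=100)
full_loss = float(losses.mean())

# Averaging many minibatch estimates recovers the full empirical loss.
estimates = [minibatch_loss(rng, losses, m=8) for _ in range(50000)]
avg = float(np.mean(estimates))
```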
Unbiased Estimator — Ex. 3: Reparameterization trick
We want to estimate ∇_θ E_{p_θ(x)}[f(x)], where p_θ(x) = N(x; μ_θ, σ_θ).
With ϵ ∼ N(ϵ; 0, 1), the quantity ∇_θ f(μ_θ + ϵ · σ_θ) is an unbiased estimator of ∇_θ E_{p_θ(x)}[f(x)]:
E_{ϵ∼N(ϵ;0,1)}[∇_θ f(μ_θ + ϵ · σ_θ)] = ∇_θ E_{p_θ(x)}[f(x)]
• Used to estimate the gradient of the encoder in VAEs

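The identity above can be checked on a case where the true gradient is known in closed form: for f(x) = x² and p(x) = N(x; μ, σ²), we have E[x²] = μ² + σ², so ∂/∂μ E[f(x)] = 2μ. A minimal sketch (the constants are illustrative assumptions, not from the paper):

```python
import numpy as np

def reparam_grad(rng, mu, sigma, n_samples=100000):
    """Reparameterization-trick estimate of d/dmu E_{x~N(mu,sigma^2)}[x^2].

    Writing x = mu + sigma*eps with eps ~ N(0,1), the gradient moves
    inside the expectation: d/dmu f(mu + sigma*eps) = 2*(mu + sigma*eps).
    """
    eps = rng.normal(size=n_samples)
    return float(np.mean(2.0 * (mu + sigma * eps)))

rng = np.random.default_rng(0)
grad = reparam_grad(rng, mu=1.5, sigma=0.7)  # true value: 2*mu = 3.0
```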
Unbiased Estimator — Ex. 4: Likelihood ratio gradient estimator (REINFORCE)
We want to estimate ∇_θ E_{p_θ(x)}[f(x)].
With x ∼ p_θ(x), the quantity f(x) ∇_θ log p_θ(x) is an unbiased estimator of ∇_θ E_{p_θ(x)}[f(x)]:
E_{p_θ(x)}[f(x) ∇_θ log p_θ(x)] = ∇_θ E_{p_θ(x)}[f(x)]
• Used in policy gradient methods in reinforcement learning

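The same toy target as in the reparameterization example (f(x) = x², p(x) = N(x; μ, σ²), true gradient 2μ) can be estimated with the likelihood ratio estimator instead; it is also unbiased, but typically noisier. A minimal sketch under the same illustrative assumptions:

```python
import numpy as np

def reinforce_grad(rng, mu, sigma, n_samples=400000):
    """Likelihood-ratio (REINFORCE) estimate of d/dmu E_{x~N(mu,sigma^2)}[x^2].

    Uses f(x) * d/dmu log p(x), where the score with respect to the mean
    of a Gaussian is (x - mu) / sigma^2.
    """
    x = rng.normal(mu, sigma, size=n_samples)
    score = (x - mu) / sigma**2
    return float(np.mean(x**2 * score))

rng = np.random.default_rng(0)
grad = reinforce_grad(rng, mu=1.5, sigma=0.7)  # true value: 2*mu = 3.0
```

Note that many more samples are used here than in the reparameterization sketch: for the same accuracy, the likelihood ratio estimator generally needs a larger sample budget, which is exactly the variance gap discussed on the next slide.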
Why unbiasedness matters
• In machine learning, we search for parameters that minimize the empirical loss by gradient descent.
• Even when the empirical loss itself cannot be computed, convergence to a local optimum can often still be guaranteed as long as an unbiased estimator of it can be computed.
e.g., stochastic gradient descent (SGD): gradient descent using the minibatch estimator is guaranteed to converge to a local optimum if the learning rate is scheduled appropriately.

Efficient Estimator
Even if an estimator is unbiased, estimation is unstable when its variance is large.
Ex. 1: In SGD, a small batch size leads to high variance and unstable training.
Ex. 2: The reparameterization trick generally has lower variance than the likelihood ratio estimator.
An ideal estimator is unbiased and has small variance; the unbiased estimator with the smallest variance is called the efficient estimator.

Latent Variable Models

Latent Variable Models
Widely used as generative models:
p_θ(x) = ∫ p_θ(x, z) dz
The parameters θ are learned by maximizing the log marginal likelihood:
log p_θ(x) = log ∫ p_θ(x, z) dz

Variational Inference
The log marginal likelihood is intractable, so a variational lower bound is used:
log p_θ(x) = log ∫ p_θ(x, z) dz ≥ E_{q(z)}[log (p_θ(x, z) / q(z))] = ℒ(θ, q)
This inequality holds with equality when q(z) = p_θ(z ∣ x)
➡ choose q(z) to be as close to p_θ(z ∣ x) as possible.

Variational Autoencoder (VAE)
Give q(z) its own parameters, q_ϕ(z ∣ x), and learn it jointly.
Objective: minimize the KL divergence to p_θ(z ∣ x):
KL(q_ϕ(z ∣ x) ∥ p_θ(z ∣ x)) = log p_θ(x) − ℒ(θ, q_ϕ)
The first term does not depend on ϕ, so in the end both θ and ϕ can be trained by maximizing ℒ.
The gradient with respect to ϕ is estimated with the reparameterization trick described above.

Problems with VAEs
• Learning of the generative parameters θ always depends on q_ϕ(z ∣ x).
• When q_ϕ(z ∣ x) is far from p_θ(z ∣ x), the bound becomes loose, and we end up maximizing something far removed from the log marginal likelihood we actually want to maximize.
(figure source: https://tips-memo.com/python-emalgorithm-gmm)

Improving VAEs
1. Increase the expressiveness of q_ϕ
   • q_ϕ is usually a Gaussian; a more flexible distribution makes the bound tighter
   • Normalizing flows, implicit variational inference
2. Change the objective function
   • Use an objective whose bound is tighter

Importance Weighted Autoencoder (IWAE)
log p_θ(x) = log E_{z^(1),…,z^(k)∼q(z)}[(1/k) ∑_{i=1}^k p_θ(x, z^(i)) / q(z^(i))]
           ≥ E_{z^(1),…,z^(k)∼q(z)}[log (1/k) ∑_{i=1}^k p_θ(x, z^(i)) / q(z^(i))] = ℒ_k(θ, q)
• For k = 1 this coincides with the VAE variational lower bound
• Equality holds as k → ∞

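The bound ℒ_k can be estimated by Monte Carlo on a toy model where log p_θ(x) is known in closed form. The model below (p(z) = N(0, 1), p(x∣z) = N(z, 1), proposal q(z∣x) = N(x/2, 1), hence marginal p(x) = N(x; 0, 2) and posterior N(x/2, 1/2)) is an assumption for illustration, not from the paper:

```python
import numpy as np

def log_normal(x, mean, var):
    """Log density of N(x; mean, var)."""
    return -0.5 * (np.log(2 * np.pi * var) + (x - mean) ** 2 / var)

def iwae_bound(rng, x, k, n_rep=2000):
    """Monte Carlo estimate of the IWAE bound L_k for the toy model."""
    z = rng.normal(x / 2, 1.0, size=(n_rep, k))
    log_w = (log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
             - log_normal(z, x / 2, 1.0))
    # log (1/k) sum_i w_i per repetition, then average over repetitions
    return float(np.mean(np.logaddexp.reduce(log_w, axis=1) - np.log(k)))

rng = np.random.default_rng(0)
x = 1.0
true_log_px = float(log_normal(x, 0.0, 2.0))
l1 = iwae_bound(rng, x, k=1)    # VAE lower bound (ELBO)
l16 = iwae_bound(rng, x, k=16)  # tighter IWAE bound
```

With this mismatched proposal, l1 sits noticeably below the true log marginal likelihood while l16 is much tighter, illustrating the two bullet points above.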
Importance Weighted Autoencoder (IWAE)
Performance improves as k increases, and IWAE outperforms the VAE.

Stochastically Unbiased Marginalization Objective

SUMO: Stochastically Unbiased Marginalization Objective
• With IWAE, the bound does not become tight unless k is made sufficiently large
• We would rather train on a quantity for which equality always holds, i.e., an unbiased estimator
• Is there a way to obtain one? ➡ use the Russian roulette estimator

Russian Roulette Estimator
Define
Δ_k = ℒ_1(θ, q)                    (k = 1)
      ℒ_k(θ, q) − ℒ_{k−1}(θ, q)    (k ≥ 2)
Then the series telescopes and coincides with the log marginal likelihood:
∑_{k=1}^∞ Δ_k = ℒ_∞(θ, q) = log p_θ(x)

Russian Roulette Estimator
Consider the following ŷ:
ŷ = Δ_1 + (∑_{k=2}^∞ Δ_k / μ) · b,   b ∼ Bernoulli(μ)
1. Flip a coin that lands heads with probability μ
2. If heads, compute the terms from k = 2 onward, divide by μ, and add them to Δ_1;
   if tails, compute only Δ_1

Russian Roulette Estimator
This ŷ is easily seen to be an unbiased estimator of ∑_{k=1}^∞ Δ_k:
E[ŷ] = Δ_1 + (∑_{k=2}^∞ Δ_k / μ) · E[b] = ∑_{k=1}^∞ Δ_k

Russian Roulette Estimator
Repeating the same trick from k = 2 onward, the following ŷ is an unbiased estimator of ∑_{k=1}^∞ Δ_k:
ŷ = ∑_{k=1}^K Δ_k / μ^{k−1},   K ∼ Geometric(K; 1 − μ)
K is the number of coin flips until tails first comes up (so K follows a geometric distribution).
Using this ŷ, we obtain an unbiased estimator of the log marginal likelihood.

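The construction above can be demonstrated on any convergent series. A minimal sketch using the toy series Δ_k = 0.5^k, whose infinite sum is exactly 1 (the series and μ = 0.6 are illustrative assumptions):

```python
import numpy as np

def russian_roulette(rng, delta, mu):
    """One draw of the Russian roulette estimator for sum_k delta(k).

    Keep flipping a coin with heads probability mu; add term k reweighted
    by 1/mu^(k-1) (its inclusion probability) until the first tails.
    """
    total, k = 0.0, 1
    while True:
        total += delta(k) / mu ** (k - 1)
        if rng.random() >= mu:  # tails: stop
            return total
        k += 1

rng = np.random.default_rng(0)
draws = [russian_roulette(rng, lambda k: 0.5 ** k, mu=0.6)
         for _ in range(50000)]
estimate = float(np.mean(draws))  # should be close to the true sum, 1.0
```

Each draw only computes finitely many terms, yet the average over draws recovers the infinite sum, which is exactly the unbiasedness property used by SUMO.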
SUMO: Stochastically Unbiased Marginalization Objective
log p_θ(x) = E_{K∼p(K)}[∑_{k=1}^K Δ_k / μ^{k−1}]
           = ℒ_1(θ, q_ϕ) + E_{K∼p(K)}[∑_{k=2}^K (ℒ_k(θ, q_ϕ) − ℒ_{k−1}(θ, q_ϕ)) / μ^{k−1}]
The first term is the same as the VAE objective; the second is a correction term.

SUMO: Stochastically Unbiased Marginalization Objective
SUMO(x) = log w^(1) + ∑_{k=2}^K [log (1/k) ∑_{i=1}^k w^(i) − log (1/(k−1)) ∑_{i=1}^{k−1} w^(i)] / μ^{k−1}
w^(i) = p_θ(x, z^(i)) / q_ϕ(z^(i) ∣ x),   K ∼ p(K),   z^(1), …, z^(K) ∼ q_ϕ(z ∣ x)
SUMO is an unbiased estimator of the log marginal likelihood:
E_{K∼p(K), z^(1),…,z^(K)∼q_ϕ(z∣x)}[SUMO(x)] = log p_θ(x)

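A sketch of a single SUMO draw on the same toy model as before (p(z) = N(0, 1), p(x∣z) = N(z, 1), true marginal p(x) = N(x; 0, 2), true posterior N(x/2, 1/2); the model, the proposal parameters, and μ = 0.6 are all illustrative assumptions). A convenient exact check: when the proposal equals the true posterior, every importance weight equals p(x), all increments vanish, and each single draw recovers log p(x) exactly.

```python
import numpy as np

def log_normal(x, mean, var):
    """Log density of N(x; mean, var)."""
    return -0.5 * (np.log(2 * np.pi * var) + (x - mean) ** 2 / var)

def sumo(rng, x, q_mean, q_var, mu=0.6):
    """One draw of SUMO(x) for the toy model, with geometric p(K)."""
    # Draw K: number of coin flips until the first tails.
    K = 1
    while rng.random() < mu:
        K += 1
    z = rng.normal(q_mean, np.sqrt(q_var), size=K)
    log_w = (log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
             - log_normal(z, q_mean, q_var))
    # Running IWAE bounds L_k = log (1/k) sum_{i<=k} w_i for k = 1..K.
    iwae = np.logaddexp.accumulate(log_w) - np.log(np.arange(1, K + 1))
    # SUMO = L_1 + sum_{k=2}^K (L_k - L_{k-1}) / mu^(k-1)
    est = iwae[0]
    for k in range(2, K + 1):
        est += (iwae[k - 1] - iwae[k - 2]) / mu ** (k - 1)
    return float(est)

rng = np.random.default_rng(0)
x = 1.0
true_log_px = float(log_normal(x, 0.0, 2.0))
# Exact-posterior sanity check: every draw equals log p(x).
draws = [sumo(rng, x, q_mean=x / 2, q_var=0.5) for _ in range(10)]
```

With an imperfect proposal, individual draws fluctuate but remain unbiased for log p_θ(x); the variance of those draws depends strongly on the choice of p(K), which motivates the next slide.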
SUMO: Variance Reduction
The choice of p(K) creates a trade-off between variance and compute:
• a distribution favoring small K reduces compute but increases variance.
Always computing the first m terms Δ_k reduces the variance:
SUMO_m(x) = log (1/m) ∑_{i=1}^m w^(i) + ∑_{k=m+1}^{m+K−1} [log (1/k) ∑_{i=1}^k w^(i) − log (1/(k−1)) ∑_{i=1}^{k−1} w^(i)] / μ^{k−m}

SUMO: Training the Encoder
From the encoder's side, the expectation of SUMO is constant with respect to the parameters ϕ,
so training the encoder on the same loss, as in a VAE, would be meaningless.
The paper instead proposes training the encoder to minimize the variance of the estimator
(valid because E[SUMO(x)] does not depend on ϕ):
∇_ϕ V(SUMO(x)) = ∇_ϕ E[(SUMO(x))²]

SUMO: Experiments (Generative Modeling)
SUMO consistently outperforms IWAE.

SUMO: Entropy Maximization
We want to approximate a distribution p*(x) whose density is known but which is hard to sample from.
When learning this with a latent variable model, reverse KL minimization is commonly used:
min_θ KL(p_θ(x) ∥ p*(x)) = min_θ E_{x∼p_θ(x)}[log p_θ(x) − log p*(x)]
The entropy term (the first term) is hard to compute.

Entropy Maximization
If IWAE is used to estimate log p_θ(x), we end up minimizing a lower bound of the objective:
E_{p_θ(x)}[log p_θ(x)] ≥ E_{p_θ(x), z^(1),…,z^(k)∼q_ϕ(z∣x)}[log (1/k) ∑_{i=1}^k p_θ(x, z^(i)) / q_ϕ(z^(i) ∣ x)]
 = E_{p_θ(x)}[log p_θ(x)] − E_{p_θ(x)}[KL(q̃_{θ,ϕ}(z ∣ x) ∥ p_θ(z ∣ x))]
Minimizing this bound ends up maximizing the KL term here instead.
Using SUMO avoids this problem:
E_{p_θ(x)}[log p_θ(x)] = E[SUMO(x)]

SUMO: Experiments (Entropy Maximization)
With IWAE, training collapses midway unless the number of samples is increased considerably.
SUMO trains stably, and the density it estimates is more accurate.

SUMO: Application to REINFORCE
REINFORCE requires that log p_θ(x) be computable:
∇_θ E_{p_θ(x)}[f(x)] = E_{p_θ(x)}[f(x) ∇_θ log p_θ(x)]
When p_θ(x) is a latent variable model, this is no longer easy to compute,
e.g., when using a latent variable model as a policy in reinforcement learning:
π_θ(a ∣ s) = ∫ p_θ(z) p_θ(a ∣ s, z) dz

SUMO: Application to REINFORCE
With SUMO, this too can be estimated without bias:
∇_θ E_{p_θ(x)}[f(x)] = E_{p_θ(x)}[f(x) ∇_θ log p_θ(x)] = E[f(x) ∇_θ SUMO(x)]

Experiments (Reinforcement Learning)
Consider maximizing E_{x∼p_θ(x)}[R(x)] in a simple reinforcement learning problem without a temporal component.
Training uses REINFORCE, i.e., policy gradients:
∇_θ E_{x∼p_θ(x)}[R(x)] = E_{p_θ(x)}[R(x) ∇_θ log p_θ(x)]
When the policy p_θ(x) is a latent variable model, SUMO can be used:
∇_θ E_{x∼p_θ(x)}[R(x)] = E[R(x) ∇_θ SUMO(x)]

SUMO: Experiments (Reinforcement Learning)
Three choices of policy p_θ(x) are compared:
1. a latent variable model
2. an autoregressive model
3. an independent model
and for 1., training with IWAE is compared against training with SUMO.

SUMO: Experiments (Reinforcement Learning)
1. Latent variable model: p_LVM(x) := ∫ ∏_i p_θ(x_i ∣ z) p(z) dz — highly expressive, and sampling is fast
2. Autoregressive model: p_Autoreg(x) := ∏_i p(x_i ∣ x_{<i}) — highly expressive, but sampling is slow
3. Independent model: p_Indep(x) := ∏_i p(x_i) — sampling is fast, but expressiveness is low

SUMO: Experiments (Reinforcement Learning)
SUMO and the autoregressive model achieve the best performance;
the autoregressive model is 19.2× slower than the latent variable model trained with SUMO.

Summary
• Latent variable models have traditionally been trained by maximizing a lower bound on the log marginal likelihood
• This work proposes SUMO, which uses the Russian roulette estimator to maximize an unbiased estimate of the log marginal likelihood
• SUMO can be applied to reverse KL minimization, reinforcement learning, and more
Impressions
• A broadly applicable idea that seems useful in many other settings (e.g., mutual information estimation)