July 01, 22
2022/7/1
Deep Learning JP
http://deeplearning.jp/seminar-2/
DL reading group materials
Unbiased Gradient Estimation for Marginal Log-likelihood Shohei Taniguchi, Matsuo Lab
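Background
• Training a deep generative model is formulated as maximizing the marginal log-likelihood $\log p_\theta(x)$.
  Latent variable models (e.g., VAE): $\log p_\theta(x) = \log \int p_\theta(x, z)\,dz$
  Energy-based models: $\log p_\theta(x) = -E_\theta(x) - \log \int \exp\left(-E_\theta(x')\right) dx'$
• In most cases, the marginal log-likelihood cannot be computed analytically.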
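Background
Example 1: VAE
Approximate the marginal log-likelihood by a variational lower bound.
$$\log p_\theta(x) = \log \int p_\theta(x, z)\,dz \ge \mathbb{E}_{q(z)}\left[\log \frac{p_\theta(x, z)}{q(z)}\right] = \mathcal{L}(\theta, q)$$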
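Background
Example 2: EBM
Approximate the gradient of the marginal log-likelihood with MCMC.
$$\nabla \log p_\theta(x_+) = -\nabla E_\theta(x_+) + \mathbb{E}_{x_- \sim p_\theta(x)}\left[\nabla E_\theta(x_-)\right] \approx -\nabla E_\theta(x_+) + \mathbb{E}_{x_T \sim q(x_T)}\left[\nabla E_\theta(x_T)\right]$$
$q(x_T)$ is the distribution obtained after running MCMC for $T$ steps; equality holds as $T \to \infty$.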
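Background
• Approximating the gradient with a variational lower bound or with finite-step MCMC introduces bias.
  Bias: the gap between the expected value of an estimator and the true value.
  An estimator whose bias is zero is called an unbiased estimator.
• Ideally, we would like to train with an unbiased estimator of the gradient of the marginal log-likelihood.
• An example of a method for unbiased estimation of the marginal log-likelihood (from a past reading-group session):
  https://www.slideshare.net/DeepLearningJP2016/dlsumo-unbiased-estimationof-log-marginal-probability-for-latent-variable-models-250013351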
Outline
Unbiased estimation methods for the marginal log-likelihood
1. Methods that estimate the marginal log-likelihood itself
  • On Multilevel Monte Carlo Unbiased Gradient Estimation for Deep Latent Variable Models (AISTATS 2021)
  • Efficient Debiased Evidence Estimation by Multilevel Monte Carlo Sampling (UAI 2021)
2. Methods that estimate the gradient of the marginal log-likelihood
  • Unbiased Contrastive Divergence Algorithm for Training Energy-Based Latent Variable Models (ICLR 2020)
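IWAE
Importance Weighted Autoencoder
$$\log p_\theta(x) = \log \mathbb{E}_{z^{(1)},\dots,z^{(k)} \sim q(z)}\left[\frac{1}{k}\sum_{i=1}^{k} \frac{p_\theta(x, z^{(i)})}{q(z^{(i)})}\right] \ge \mathbb{E}_{z^{(1)},\dots,z^{(k)} \sim q(z)}\left[\log \frac{1}{k}\sum_{i=1}^{k} \frac{p_\theta(x, z^{(i)})}{q(z^{(i)})}\right] = \mathbb{E}_{z^{(1)},\dots,z^{(k)} \sim q(z)}\left[\mathcal{L}_k(\theta, q)\right]$$
For $k = 1$ this coincides with the VAE bound; equality holds as $k \to \infty$.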
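As a rough numerical check of the bound above, the following sketch evaluates $\mathbb{E}[\mathcal{L}_k]$ by Monte Carlo for a few values of $k$. The toy model is an assumption made for illustration and does not come from the slides: $z \sim N(0, 1)$, $x \mid z \sim N(z, 1)$, with $q(z)$ set to the prior, so that $\log p_\theta(x)$ is available in closed form. The helper names `log_normal` and `iwae_samples` are hypothetical.

```python
import numpy as np

def log_normal(x, mean, var):
    """Log density of N(mean, var) evaluated at x."""
    return -0.5 * (np.log(2.0 * np.pi * var) + (x - mean) ** 2 / var)

def iwae_samples(x, k, n_rep, rng):
    """Draw n_rep independent values of L_k = log (1/k) sum_i p(x, z_i) / q(z_i)."""
    z = rng.standard_normal((n_rep, k))    # z_i ~ q(z) = N(0, 1)
    log_w = log_normal(x, z, 1.0)          # q is the prior, so p(x, z)/q(z) = p(x | z)
    m = log_w.max(axis=1, keepdims=True)   # stabilized log-mean-exp over the k samples
    return m[:, 0] + np.log(np.exp(log_w - m).mean(axis=1))

rng = np.random.default_rng(0)
x = 1.5
true_log_px = log_normal(x, 0.0, 2.0)      # the marginal of the toy model is N(0, 2)
bounds = {k: iwae_samples(x, k, 20000, rng).mean() for k in (1, 8, 64)}
for k, value in bounds.items():
    print(f"k={k:3d}  E[L_k] ~ {value:.3f}   (log p(x) = {true_log_px:.3f})")
```

In this toy setting the averaged bound should increase with $k$ and approach $\log p_\theta(x)$ from below, which is the behavior the slide describes.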
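Series representation of the marginal log-likelihood
Define
$$\Delta_k = \begin{cases} \mathcal{L}_1(\theta, q) & (k = 1) \\ \mathcal{L}_k(\theta, q) - \mathcal{L}_{k-1}(\theta, q) & (k \ge 2) \end{cases}$$
Then the expectation of the series $\sum_{k=1}^{\infty} \Delta_k$ equals the marginal log-likelihood:
$$\mathbb{E}\left[\sum_{k=1}^{\infty} \Delta_k\right] = \mathbb{E}\left[\mathcal{L}_\infty(\theta, q)\right] = \log p_\theta(x)$$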
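Russian Roulette Estimator
Consider the following estimator $\hat{y}$:
$$\hat{y} = \Delta_1 + \frac{\sum_{k=2}^{\infty} \Delta_k}{\mu} \cdot b, \quad b \sim \mathrm{Bernoulli}(\mu)$$
1. Flip a coin that comes up heads with probability $\mu$.
2. If it comes up heads, compute the terms from $k = 2$ onward, divide them by $\mu$, and add the result to $\Delta_1$.
   If it comes up tails, compute only $\Delta_1$.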
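Russian Roulette Estimator
$\hat{y}$ is an unbiased estimator of $\sum_{k=1}^{\infty} \Delta_k$:
$$\hat{y} = \Delta_1 + \frac{\sum_{k=2}^{\infty} \Delta_k}{\mu} \cdot b, \quad b \sim \mathrm{Bernoulli}(b; \mu)$$
Taking the expectation over $b$ and using $\mathbb{E}[b] = \mu$,
$$\mathbb{E}[\hat{y}] = \Delta_1 + \frac{\sum_{k=2}^{\infty} \Delta_k}{\mu} \cdot \mathbb{E}[b] = \sum_{k=1}^{\infty} \Delta_k$$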
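Russian Roulette Estimator
Repeating the same construction for $k = 2$ and onward, the following $\hat{y}$ is also an unbiased estimator of $\sum_{k=1}^{\infty} \Delta_k$:
$$\hat{y} = \sum_{k=1}^{K} \frac{\Delta_k}{\mu^{k-1}}, \quad K \sim \mathrm{Geometric}(K; 1 - \mu)$$
$K$ is the number of coin flips until tails first appears (it follows a geometric distribution).
Using this $\hat{y}$ gives an unbiased estimator of the marginal log-likelihood.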
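The unbiasedness of the geometric-truncation estimator can be checked numerically with a small sketch. The series used here is a deterministic toy, $\Delta_k = 0.5^k$ with $\sum_{k \ge 1} \Delta_k = 1$; this is an assumption for illustration only (in the slides $\Delta_k$ is itself random), and the function names are hypothetical.

```python
import numpy as np

def delta(k):
    """Toy terms Delta_k = 0.5**k, whose infinite sum is exactly 1 (illustrative assumption)."""
    return 0.5 ** k

def russian_roulette(mu, rng):
    """y_hat = sum_{k=1}^{K} Delta_k / mu**(k-1), with K ~ Geometric(1 - mu)."""
    K = rng.geometric(1.0 - mu)    # number of flips up to and including the first tails
    k = np.arange(1, K + 1)
    return float(np.sum(delta(k) / mu ** (k - 1)))

rng = np.random.default_rng(0)
estimates = np.array([russian_roulette(0.7, rng) for _ in range(50000)])
print("mean of y_hat:", estimates.mean(), "(target: 1.0)")
```

Each call evaluates only finitely many terms, yet the sample mean should be close to the value of the infinite sum, because term $k$ is included with probability $\mu^{k-1}$ and reweighted by $1/\mu^{k-1}$.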
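Single Sample Estimator
In the same way, the following $\hat{y}$ is also an unbiased estimator of $\sum_{k=1}^{\infty} \Delta_k$:
$$\hat{y} = \frac{\Delta_K}{p(K)} = \frac{\Delta_K}{\mu^{K-1}(1 - \mu)}$$
$$\mathbb{E}[\hat{y}] = \sum_{K=1}^{\infty} p(K) \cdot \frac{\Delta_K}{p(K)} = \sum_{k=1}^{\infty} \Delta_k$$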
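A minimal vectorized sketch of the single-sample estimator, again on the toy series $\Delta_k = 0.5^k$ (sum equal to 1), which is an illustrative assumption and not part of the slides:

```python
import numpy as np

mu = 0.7
rng = np.random.default_rng(0)

K = rng.geometric(1.0 - mu, size=100000)   # P(K = k) = mu**(k-1) * (1 - mu), k = 1, 2, ...
p_K = mu ** (K - 1) * (1.0 - mu)           # probability of the index that was drawn
estimates = 0.5 ** K / p_K                 # Delta_K / p(K) with toy Delta_k = 0.5**k

print("mean of y_hat:", estimates.mean(), "(target: 1.0)")
```

Only one term $\Delta_K$ is evaluated per draw. The price is that the weight $1/p(K)$ grows with $K$, so the variance depends on how fast $\Delta_k$ decays relative to $p(k)$.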
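SUMO
Stochastically Unbiased Marginalization Objective
$$\log p_\theta(x) = \mathbb{E}_{K \sim p(K)}\left[\sum_{k=1}^{K} \frac{\Delta_k}{\mu^{k-1}}\right] = \underbrace{\mathcal{L}_1(\theta, q_\phi)}_{\text{same as VAE}} + \underbrace{\mathbb{E}_{K \sim p(K)}\left[\sum_{k=2}^{K} \frac{\mathcal{L}_k(\theta, q_\phi) - \mathcal{L}_{k-1}(\theta, q_\phi)}{\mu^{k-1}}\right]}_{\text{correction term}}$$
Here the expectation is also taken over the samples $z$ used to compute each $\mathcal{L}_k$.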
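Issues with SUMO
• The variance of the estimator tends to be large.
• In the worst case, the variance diverges to infinity.
• The variance can be controlled through the choice of $p(K)$, but lowering the variance makes $K$ larger and therefore increases the computational cost.
• Ideally, the variance is finite and the expected computational cost is also finite.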
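Series representation of the marginal log-likelihood (recap)
Define
$$\Delta_k = \begin{cases} \mathcal{L}_1(\theta, q) & (k = 1) \\ \mathcal{L}_k(\theta, q) - \mathcal{L}_{k-1}(\theta, q) & (k \ge 2) \end{cases}$$
Then the expectation of the series $\sum_{k=1}^{\infty} \Delta_k$ equals the marginal log-likelihood:
$$\mathbb{E}\left[\sum_{k=1}^{\infty} \Delta_k\right] = \mathbb{E}\left[\mathcal{L}_\infty(\theta, q)\right] = \log p_\theta(x)$$
Other ways of designing $\Delta_k$ are also possible.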
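Series representation of the marginal log-likelihood (revised)
$$\Delta_k = \begin{cases} \mathcal{L}_1(\theta, q) & (k = 1) \\ \mathcal{L}_{2^{k-1}}(\theta, q) - \frac{1}{2}\left(\mathcal{L}^{(1)}_{2^{k-2}}(\theta, q) + \mathcal{L}^{(2)}_{2^{k-2}}(\theta, q)\right) & (k \ge 2) \end{cases}$$
Here $\mathcal{L}^{(1)}$ and $\mathcal{L}^{(2)}$ denote the bound evaluated on the two halves of the samples used for the first term. The indices are written so that $\mathbb{E}[\Delta_k] = \mathbb{E}[\mathcal{L}_{2^{k-1}}] - \mathbb{E}[\mathcal{L}_{2^{k-2}}]$ for $k \ge 2$ and the series telescopes.
In this case as well, the expectation of the series $\sum_{k=1}^{\infty} \Delta_k$ equals the marginal log-likelihood.
An estimator built from this series can achieve both finite variance and finite computational cost.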
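The multilevel construction can be sketched as follows. Everything specific here is an assumption for illustration: the toy model ($z \sim N(0,1)$, $x \mid z \sim N(z,1)$, proposal equal to the prior), the indexing convention (level $k$ uses $2^{k-1}$ samples, so that $\Delta_1 = \mathcal{L}_1$), and the level distribution $P(K = k) \propto 2^{-1.5k}$ for $k \ge 2$, which is one common choice in the multilevel Monte Carlo literature and is not stated on the slide.

```python
import numpy as np

def log_normal(x, mean, var):
    """Log density of N(mean, var) evaluated at x."""
    return -0.5 * (np.log(2.0 * np.pi * var) + (x - mean) ** 2 / var)

def log_mean_exp(a):
    m = a.max()
    return m + np.log(np.mean(np.exp(a - m)))

def delta(x, k, rng):
    """Delta_1 = L_1; for k >= 2, L_n minus the average of L_{n/2} on the two halves, n = 2**(k-1)."""
    n = 2 ** (k - 1)
    z = rng.standard_normal(n)             # z_i ~ q(z) = N(0, 1), the prior of the toy model
    log_w = log_normal(x, z, 1.0)          # log p(x, z_i)/q(z_i) = log p(x | z_i)
    if k == 1:
        return log_w[0]
    half = n // 2
    return log_mean_exp(log_w) - 0.5 * (log_mean_exp(log_w[:half]) + log_mean_exp(log_w[half:]))

def estimate(x, rng, r=2.0 ** -1.5):
    """Delta_1 plus a single-term estimate of sum_{k>=2} Delta_k, with P(K = k) = (1 - r) * r**(k-2)."""
    K = 1 + rng.geometric(1.0 - r)         # K takes values 2, 3, ...
    p_K = (1.0 - r) * r ** (K - 2)
    return delta(x, 1, rng) + delta(x, K, rng) / p_K

rng = np.random.default_rng(0)
x = 1.5
true_log_px = log_normal(x, 0.0, 2.0)      # the marginal of the toy model is N(0, 2)
estimates = np.array([estimate(x, rng) for _ in range(20000)])
print("mean estimate:", estimates.mean(), " log p(x):", true_log_px)
```

Reusing the same samples for the full bound and for the two half-sample bounds makes $\Delta_k$ shrink quickly with $k$, which is what allows a level distribution under which both the variance and the expected number of samples stay finite.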
Experiments
Image generation
The proposed method improves performance over IWAE and SUMO.
Outline
Unbiased estimation methods for the marginal log-likelihood
1. Methods that estimate the marginal log-likelihood itself
  • On Multilevel Monte Carlo Unbiased Gradient Estimation for Deep Latent Variable Models (AISTATS 2021)
  • Efficient Debiased Evidence Estimation by Multilevel Monte Carlo Sampling (UAI 2021)
2. Methods that estimate the gradient of the marginal log-likelihood
  • Unbiased Contrastive Divergence Algorithm for Training Energy-Based Latent Variable Models (ICLR 2020)
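Training EBMs
Approximate the gradient of the marginal log-likelihood with MCMC.
$$\nabla \log p_\theta(x_+) = -\nabla E_\theta(x_+) + \mathbb{E}_{x_- \sim p_\theta(x)}\left[\nabla E_\theta(x_-)\right] \approx -\nabla E_\theta(x_+) + \mathbb{E}_{x_T \sim q(x_T)}\left[\nabla E_\theta(x_T)\right]$$
$q(x_T)$ is the distribution obtained after running MCMC for $T$ steps; equality holds as $T \to \infty$.
When $T$ is finite, the gradient estimate is biased.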
Series expansion of the gradient
Using $p_\theta(x) = q(x_\infty)$, the expectation can be rewritten as a series:
$$\mathbb{E}_{x_- \sim p_\theta(x)}\left[\nabla E_\theta(x_-)\right] = \mathbb{E}_{x_\infty \sim q(x_\infty)}\left[\nabla E_\theta(x_\infty)\right] = \mathbb{E}_{x_k \sim q(x_k)}\left[\nabla E_\theta(x_k)\right] + \sum_{t=k}^{\infty}\left(\mathbb{E}_{x_{t+1} \sim q(x_{t+1})}\left[\nabla E_\theta(x_{t+1})\right] - \mathbb{E}_{x_t \sim q(x_t)}\left[\nabla E_\theta(x_t)\right]\right)$$
Let $y_t$ follow the same marginal distribution as $x_t$. Then the last expectation can be taken over $y_t$ instead:
$$= \mathbb{E}_{x_k \sim q(x_k)}\left[\nabla E_\theta(x_k)\right] + \sum_{t=k}^{\infty}\left(\mathbb{E}_{x_{t+1} \sim q(x_{t+1})}\left[\nabla E_\theta(x_{t+1})\right] - \mathbb{E}_{y_t \sim q(y_t)}\left[\nabla E_\theta(y_t)\right]\right)$$
$$= \mathbb{E}\left[\nabla E_\theta(x_k) + \sum_{t=k}^{\infty}\left(\nabla E_\theta(x_{t+1}) - \nabla E_\theta(y_t)\right)\right]$$
Suppose in addition that $x_{t+1} = y_t$ for all $t \ge \tau$. Then every term with $t \ge \tau$ vanishes:
$$= \mathbb{E}\left[\nabla E_\theta(x_k) + \sum_{t=k}^{\tau-1}\left(\nabla E_\theta(x_{t+1}) - \nabla E_\theta(y_t)\right)\right]$$
The infinite series has been rewritten as a finite sum.
The remaining question: how do we construct $y_t$ such that $x_{t+1} = y_t$ for $t \ge \tau$?
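Coupling
• Run two correlated MCMC chains until the samples that are one step apart coincide.
• If the chains are run independently, the samples almost never coincide, so correlating them well is the technical crux.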
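Coupling
• Let $\xi$ and $\eta$ be random variables distributed according to $p(\xi)$ and $q(\eta)$, respectively.
• The two need not be independent.
• Then the probability $P(\xi = \eta)$ that $\xi$ and $\eta$ coincide satisfies the following inequality:
$$P(\xi = \eta) \le 1 - \|p - q\|_{\mathrm{TV}} = \int \min\{p(x), q(x)\}\,dx$$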
Maximal Coupling
Figure: https://commons.wikimedia.org/wiki/File:Total_variation_distance.svg
$$P(\xi = \eta) \le 1 - \|p - q\|_{\mathrm{TV}} = \int \min\{p(x), q(x)\}\,dx$$
• Sampling $\xi$ and $\eta$ by the following procedure attains this upper bound.
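Maximal Coupling
Figure: https://commons.wikimedia.org/wiki/File:Total_variation_distance.svg
Procedure
1. Sample $\xi$ from $p$.
2. With probability $\alpha = \min\left(q(\xi)/p(\xi), 1\right)$, set $\eta = \xi$.
  • $\eta$ is sampled from the region of the figure where $p$ and $q$ overlap.
3. With probability $1 - \alpha$, sample $\eta$ from $\tilde{q}(\eta) \propto \max\left(q(\eta) - p(\eta), 0\right)$.
  • $\eta$ is sampled from the remaining part of $q$.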
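The procedure can be sketched for two unit-variance Gaussians. The choice $p = N(0, 1)$, $q = N(1, 1)$ and the function names are assumptions for illustration. Step 3 is implemented by rejection sampling: a draw $\eta \sim q$ is accepted with probability $\max(q(\eta) - p(\eta), 0)/q(\eta)$, which yields a sample from the residual density $\tilde{q}$.

```python
import math
import numpy as np

def normal_pdf(x, mean):
    """Density of N(mean, 1) at x."""
    return math.exp(-0.5 * (x - mean) ** 2) / math.sqrt(2.0 * math.pi)

def maximal_coupling(mean_p, mean_q, rng):
    """Return (xi, eta) with xi ~ N(mean_p, 1), eta ~ N(mean_q, 1), equal as often as possible."""
    xi = rng.normal(mean_p, 1.0)                                   # step 1: xi ~ p
    alpha = min(normal_pdf(xi, mean_q) / normal_pdf(xi, mean_p), 1.0)
    if rng.uniform() < alpha:                                      # step 2: eta = xi
        return xi, xi
    while True:                                                    # step 3: eta ~ q_tilde
        eta = rng.normal(mean_q, 1.0)
        if rng.uniform() * normal_pdf(eta, mean_q) > normal_pdf(eta, mean_p):
            return xi, eta

rng = np.random.default_rng(0)
pairs = np.array([maximal_coupling(0.0, 1.0, rng) for _ in range(20000)])
meet_rate = np.mean(pairs[:, 0] == pairs[:, 1])
# For N(0, 1) and N(1, 1): integral of min(p, q) = 2 * Phi(-1/2), with Phi the standard normal CDF.
overlap = 1.0 + math.erf(-0.5 / math.sqrt(2.0))
print("empirical P(xi = eta):", meet_rate, " integral of min(p, q):", overlap)
```

The empirical meeting frequency should match $\int \min\{p, q\}\,dx$, while the second coordinate still has marginal $q$.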
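Maximal Coupling
Example: both distributions are standard normal
1. Draw a sample $\xi$ from $p$.
2. Set $\eta = \xi$ and treat it as the sample from $q$.
https://colcarroll.github.io/couplings/static/maximal_couplings.html
• In this case, $\xi$ and $\eta$ both follow the standard normal distribution, and $\xi = \eta$ with probability 1.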
Maximal Coupling
Other examples
https://colcarroll.github.io/couplings/static/maximal_couplings.html
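Coupling
• By sampling each MCMC step with a maximal coupling, the two samples can be made to coincide with high probability.
• Debiasing MCMC with couplings was first proposed in [Jacob, et al., 2017]: https://arxiv.org/abs/1708.03625
• This paper applies it to the training of EBMs (RBMs in particular).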
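To make the idea concrete, here is a hedged sketch of a coupled random-walk Metropolis sampler in one dimension. It is not the RBM setting of the paper: the target $N(2, 1)$, the initial distribution $N(0, 1)$, the proposal scale, and all function names are assumptions for illustration. The two chains use maximally coupled proposals and a shared uniform for the accept step, and the estimator is $h(x_k) + \sum_{t=k}^{\tau-1}\left(h(x_{t+1}) - h(y_t)\right)$, where $\tau$ is the first $t$ with $x_{t+1} = y_t$.

```python
import math
import numpy as np

def log_target(x):
    """Unnormalized log density of the toy target N(2, 1)."""
    return -0.5 * (x - 2.0) ** 2

def normal_pdf(x, mean, std):
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))

def coupled_proposals(mx, my, std, rng):
    """Maximal coupling of the proposals N(mx, std^2) and N(my, std^2)."""
    px = rng.normal(mx, std)
    if rng.uniform() * normal_pdf(px, mx, std) <= normal_pdf(px, my, std):
        return px, px
    while True:
        py = rng.normal(my, std)
        if rng.uniform() * normal_pdf(py, my, std) > normal_pdf(py, mx, std):
            return px, py

def accept(log_ratio, u):
    return u < math.exp(min(0.0, log_ratio))

def mh_step(x, std, rng):
    prop = rng.normal(x, std)
    return prop if accept(log_target(prop) - log_target(x), rng.uniform()) else x

def coupled_mh_step(x, y, std, rng):
    px, py = coupled_proposals(x, y, std, rng)
    u = rng.uniform()                      # shared uniform: chains stay together once they meet
    new_x = px if accept(log_target(px) - log_target(x), u) else x
    new_y = py if accept(log_target(py) - log_target(y), u) else y
    return new_x, new_y

def unbiased_estimate(h, k, std, rng):
    """h(x_k) + sum_{t=k}^{tau-1} (h(x_{t+1}) - h(y_t)), where x_{t+1} = y_t for t >= tau."""
    x_t = rng.normal(0.0, 1.0)             # x_0 ~ N(0, 1)
    y_t = rng.normal(0.0, 1.0)             # y_0 ~ N(0, 1), same marginal as x_0
    x_next = mh_step(x_t, std, rng)        # x_1, so the x chain is one step ahead of y
    t, total = 0, 0.0
    while True:
        if t == k:
            total = h(x_t)
        if t >= k:
            total += h(x_next) - h(y_t)    # this term is zero once the chains have met
            if x_next == y_t:
                return total
        x_t = x_next
        x_next, y_t = coupled_mh_step(x_next, y_t, std, rng)
        t += 1

rng = np.random.default_rng(0)
estimates = np.array([unbiased_estimate(lambda v: v, 2, 1.0, rng) for _ in range(5000)])
print("mean of unbiased estimates:", estimates.mean(), "(target mean: 2.0)")
```

Each estimate uses only a finite, random number of MCMC steps, yet its expectation is the expectation under the target. A plain average of $h(x_2)$ from chains started at $N(0, 1)$ would instead be biased toward the initial distribution.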
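Experiments
Toy data
With ordinary MCMC, performance degrades partway through training; with the proposed method this does not happen.
Coupling takes about 10 steps at most.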
Experiments
Image generation (Fashion MNIST)
Similar behavior is observed for image generation.
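Summary
• Introduced techniques for unbiased estimation of the marginal log-likelihood and its gradient.
  1. A method for performing Russian roulette estimation with finite variance and finite computational cost
  2. A method for making MCMC unbiased using couplings
• Both are technically interesting and seem applicable in many settings.
• The experiments are small in scale, so it remains an open question whether the methods also work for large models.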