>100 Views
April 22, 22
スライド概要
2022/04/22
Deep Learning JP:
http://deeplearning.jp/seminar-2/
DL輪読会資料
DEEP LEARNING JP [DL Papers] Causality Inspired Representation Learning for Domain Generalization Yuting Lin, Kokusai Kogyo Co., Ltd.(国際航業) http://deeplearning.jp/ 1
書誌情報 • タイトル – Causality Inspired Representation Learning for Domain Generalization • 著者 – Fangrui Lv1 Jian Liang2 Shuang Li1,∗ Bin Zang1 Chi Harold Liu1 Ziteng Wang3 Di Liu2 – 1Beijing Institute of Technology, China 2Alibaba Group, China 3Yizhun Medical AI Co., Ltd, China • CVPR2022に採択 • Paper – https://arxiv.org/abs/2203.14237 • Code – https://github.com/BIT-DA/CIRL 2
概要 • DGに向けた、因果関係を用いるrepresentation label-related causal factor (domain independent): 𝑆 domain-related non-causal factor(label independent): 𝑈 入力に𝑋に𝑆と𝑈が混在し、𝑋 = 𝑓 𝑆, 𝑈 を直接分解するのが困難 • Sに潜在的なnon-causal情報が混在 • 互いに独立していない因数に分解すると、過剰なSになってしまう 3
概要 • Causality Inspired Representation Learning (CIRL)によるdomain generalization手法の提案 陰的な因果メカニズムを発掘することで、汎化力を向上 causal intervention moduleでSとUを分離する同時に、別domainのXを生成 • domain不変なrepresentationを学習 factorization moduleでrepresentationの各dimensionが互い独立になるように学習 • 理想的なSを推定 taskに向けたcausal factorになるように、adversarial mask moduleで学習された representationが有効的なcausa factorになる 4
既往研究 – domain generalization • 目的:source domainから、unseen target domainに汎用的なモデルの作成 • 既存手法 domain不変representationの学習 • kernel-based optimization、adversarial learning、second-order correlation、Variational Bayes等 Data augmentationでsource domainのバリエーションを増やす • Discriminatorの勾配で入力データを摂動・生成、domain augmentation meta learning • domain shiftをmeta-trainとmeta-testの違いをみなす low-rank decomposition、multi-task learning、gradient-guided dropout等の提案も 5
既往研究 - Causal Mechanism • source domainから離れている推論によく使用される (causal diagram等)厳密な仮定が必要とされてきた • MatchDGはDGに因果関係を導入 contrastive learningで異なるsource domainから不変なrepresentationを学習 提案手法はdimension-wise representationsからcausal factorsを抽出するため、 厳密な仮定が不要 6
背景 • Principle 1: Common Cause Principle 変数XとYが相関する場合、変数Sが存在。Sが①両方とも因果的に影響する②Sを条 件とする場合、XとYの独立させる全ての依存関係を説明する DGをstructural causal model(SCM)で定式化 • 𝑋: = 𝑓 𝑆, 𝑈, 𝑉1 , 𝑆 ⫫ 𝑈 ⫫ 𝑉1 • 𝑌: = ℎ 𝑆, 𝑉2 = ℎ 𝑔 𝑋 , 𝑉2 , 𝑉1 ⫫ 𝑉2 where X=input image, Y=label, S=causal factor, U=non-causal factor, V1,V2=jointly independent noise • Sが分かれば、ℎ∗ = arg minℎ 𝔼𝑃 ℓ ℎ 𝑔 𝑋 , 𝑌 = arg min 𝔼𝑃 ℓ ℎ 𝑆 , 𝑌 でℎを最適化することで、 ℎ 汎用的なモデルを作成可能 • Sを直接推定できない 7
背景 • Principle 2: Independent Causal Mechanisms (ICM) Principle Causeが与えられた場合、各変数の条件付き分布は互い独立 • ①Causal factor set 𝑠1 , 𝑠2 , ⋯ , 𝑠𝑁 に対し、𝑃 𝑠𝑖 𝑃𝐴𝑖 と𝑃 𝑠𝑗 𝑃𝐴𝑗 が互いに影響しない • where 𝑃𝐴𝑖 は𝑠𝑖 のcausal graph上の親 Sは因数分解できる • 𝑃 𝑠1 , 𝑠2 , ⋯ , 𝑠𝑁 = ς𝑁 𝑖 𝑃 𝑠𝑖 𝑃𝐴𝑖 • Principle 1と2から、Sは3つの属性がある SはUから分離できる(𝑆 ⫫ 𝑈)。Uを摂動しても、Sに影響しない 𝑠1 , 𝑠2 , ⋯ , 𝑠𝑁 は互いに独立 学習できたSがタスクに対してcausally sufficient(全ての独立変数を説明できる) 8
提案手法の全体図 • Causal Intervention ModuleでSとUを分離 • Causal Factorization Moduleでcausal factorsを因数分解 • Adversarial Mask Moduleでcausally sufficientなrepresentationを実現 9
提案手法 - Causal Intervention Module • Sは入力データの摂動に対して不変であることから、SとUを分離できる フーリエ変換は、位相成分がhigh-levelを保存、振幅成分がlow-levelな統計情報を 保存 𝑂 • ℱ 𝓍 𝑂 = 𝒜 𝓍 𝑂 × 𝑒 −𝑗×𝒫 𝓍 • where 𝒜 𝓍 𝑂 =振幅成分,𝒫 𝓍 𝑂 =位相成分 提案手法は、振幅成分を変化させ、位相成分を不変とするフーリエ変換で入力デー タを摂動 • 𝒜መ 𝓍 𝑂 = 1 − 𝜆 𝒜 𝓍 𝑂 + 𝜆𝒜 𝓍 ′ 𝑂 • where 𝓍 ′ 別のsource domainからの任意のデータ 𝑂 𝑎 −1 𝑎 𝑎 𝑂 −𝑗×𝒫 𝓍 መ • 𝓍 =ℱ ℱ 𝓍 , ℱ 𝓍 =𝒜 𝓍 ×𝑒 10
提案手法 - Causal Intervention Module • 摂動前後のrepresentationは不変なSを生成するgeneratorを学習 Representation: 𝑟 = 𝑔(𝑥) ො ∈ ℝ1×𝑁 1 𝑁 𝑜 𝑎 最適化目標:max𝑔ො σ𝑁 𝐶𝑂𝑅 𝑟 ǁ 𝑖=1 𝑖 , 𝑟𝑖ǁ where 𝑟𝑖ǁ 𝑜 , 𝑟𝑖ǁ 𝑎 はz-scoreで正規化しrepresentation Uと独立するSのrepresentationを取得 11
提案手法 - Causal Factorization Module • causal factorの各要素は互いに独立 最適化目標: 1 σ𝑖≠𝑗 𝐶𝑂𝑅 min𝑔ො 𝑁(𝑁−1) 𝑟𝑖ǁ 𝑜 , 𝑟𝑗ǁ 𝑎 , 𝑖 ≠ 𝑗 • 𝑟𝑖𝑜 , 𝑟𝑗𝑎 の共分散行列が単位行列に近づけることで最適化できる • ℒ𝐹𝑎𝑐 = 1 2 𝑪−I • where 𝐶 = 2 𝐹 𝑟𝑖ǁ 𝑜 ,𝑟𝑗ǁ 𝑎 𝑟𝑖ǁ 𝑜 𝑟𝑖ǁ 𝑎 , 𝑖, 𝑗 ∈ 1,2, ⋯ , 𝑁 12
提案手法 - Adversarial Mask Module • 目標:学習したrepresentationの各dimensionがタスクに貢献(causally efficient) 直接representationでタスクを推定してlossを計算するのは、全てのdimensionが貢献 する保証がない 学習できるmasker 𝑤で、各dimensionの貢献度を推定 ෝ • 𝑚 = 𝐺𝑢𝑚𝑏𝑒𝑙-𝑆𝑜𝑓𝑡𝑚𝑎𝑥 𝑤 ෝ 𝑟 , 𝑘𝑁 ∈ ℝ𝑁 • m: superior dimensions, 1-m: inferior dimensions 𝑠𝑢𝑝 • ℒ𝑐𝑙𝑠 = ℓ ℎ1 𝑟 𝑂 ⊙ 𝑚𝑂 , 𝑦 + ℓ ℎ1 𝑟 𝑎 ⊙ 𝑚𝑎 , 𝑦 𝑖𝑛𝑓 • ℒ𝑐𝑙𝑠 = ℓ ℎ 2 𝑟 𝑂 ⊙ 1 − 𝑚𝑂 , 𝑦 + ℓ ℎ 2 𝑟 𝑎 ⊙ 1 − 𝑚𝑎 , 𝑦 13
提案手法のLoss関数 𝑖𝑛𝑓 𝑠𝑢𝑝 • 𝑔ො、ℎ1 、ℎ2 を学習する場合、ℒ𝑐𝑙𝑠 とℒ𝑐𝑙𝑠 をminimize 𝑠𝑢𝑝 𝑖𝑛𝑓 min𝑔,ො ℎ1,ℎ2 ℒ𝑐𝑙𝑠 + ℒ𝑐𝑙𝑠 + 𝜏ℒ𝐹𝑎𝑐 • 𝑖𝑛𝑓 𝑠𝑢𝑝 𝑤を学習する場合、ℒ ෝ 𝑐𝑙𝑠 をminimize、ℒ 𝑐𝑙𝑠 をmaximize 𝑠𝑢𝑝 𝑖𝑛𝑓 min ℒ𝑐𝑙𝑠 − ℒ𝑐𝑙𝑠 ෝ 𝑤 14
Digits-DGの実験結果 • domain-invariant representation based methods: CCSA, MMD-AAE • causal intervention moduleのみ: FACT • 提案手法の有効性を確認 15
PACSの実験結果 • 異なるbackboneでも提案手法を検証 • Photo domainでの結果がSOTAでない:画像の質が悪く、causal情報が不 十分と考えられる 16
Office-Homeの実験結果 • Challengingなデータセットでも提案手法の有効性を確認 17
Ablation Study • Causal Intervention (CInt.) module, Causal Factorization (CFac.) module and Adversarial Mask (AdvM.) moduleの効果を検証 18
Visual Explanation • 最後のconv layerの出力をGrad-camで可視化 • 提案手法は、よりsemantic(category-related)なところに注目 19
Independence of Causal Representation • Sの各要素が独立するかを評価 Sの共分散行列が対角行列になっているのか: 𝐶 2 𝐹 − 𝑑𝑖𝑎𝑔(𝐶) 2 𝐹 • 提案手法は、異なるbackboneでも有効性を確認 20
Representation Importance • Representationの各dimensionがタスクに貢献するか(causally efficient)を 評価 classifier第一層の重みで、各dimensionの重要度を評価 • 提案手法は、各dimensionの重要度が高く、バラツキも少ない 21
Parameter Sensitivity • ハイパラ𝜏、𝑘を検証 • 比較的にハイパラに敏感でない 22
まとめ • 因果関係を利用したDG手法を提案 摂動した入力データから、domain不変なCausal factorとnon-causal factorを分離し、 causal representationを学習 representationを互い独立する因数に分解することで、 noiseや間違ったcausal factorの要素を排除 タスクへの貢献度が高いと低いcausal factorの要素を推定し、 representationの汎 化性を更に向上 23