RMNet_talk_CausalityDiscussionGroup


April 12, 2023

Slide overview

Machine learning engineer/researcher


Text of each slide
1.

Regret Minimization for Causal Inference on Large Treatment Space (AISTATS'21)
2023/4/12
Akira Tanimoto¹˒²˒³, Tomoya Sakai¹˒³, Takashi Takenouchi⁴˒³, Hisashi Kashima² (@akira_tanimoto)
(1. NEC, 2. Kyoto Univ., 3. Riken AIP, 4. Future Univ. Hakodate)
© NEC Corporation 2021

2.

Contents
◆ Background: Association does not imply causation
◆ Remark 1/2: Difference b/w Supervised ML and Causal Inference
◆ Remark 2/2: Causal Inference for Decision-making
◆ Problem Setting
◆ Theory
◆ Method
◆ Intuition
◆ Experiment
◆ Result
◆ Summary

3.

Background: Association Does Not Imply Causation
Decision-making based on "what-if" prediction needs consideration of causality.
◆ We aim at maximizing the outcome by choosing the best action under an unknown/uncertain mechanism.
◆ Learn the outcome predictor f̂(x, a), choose the best-predicted action â*, and then observe the actual outcome y ≃ f*(x, â*). (Covariate x and action a enter the mechanism f(⋅, a), which produces the outcome y.)
◆ Conventional supervised learning is not sufficient, because
 ◆ "Those who were prescribed have a poor prognosis" ⇏ "Prescribing won't work."
 ◆ ∵ The logging policy μ(a | x), aka the propensity, is biased.
 ◆ → Need causal inference (?)
(Figure: e.g., a 30-year-old patient is not prescribed, while a 70-year-old patient with hyperglycemia is prescribed, so μ(a | x) ≠ Uniform.)

4.

Remark 1/2: Difference b/w Supervised ML and Causal Inference
We only have partial (bandit) feedback in causal inference.
◆ Supervised ML: a feature/image x yields a prediction ŷ = f(x) with a full label ("Dog", "Cat", "Monkey", "Pig") and a loss ℓ, e.g., the cross-entropy
   L_CE = −(1/N) Σ_n Σ_i y_i^(n) log f̂_i^(n)   (inner sum over classes i).
◆ Causal inference: a covariate x and an action a yield the outcome y = y_a for the chosen action only; the other potential outcomes y_a remain unobserved.
   L_IPW = (1/N) Σ_n (1 / μ(a^(n) | x^(n))) (y^(n) − f̂(x^(n), a^(n)))²
   Inverse Probability Weighting (IPW) uses the propensity μ for de-biasing.
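The IPW objective above can be sketched numerically. The following is a minimal numpy illustration, not the talk's code: all data, sizes, and names are hypothetical, and the propensities are assumed known. Each factual squared error is weighted by 1/μ(a | x).

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical logged bandit data: one factual action a per unit, drawn from
# the logging policy mu(a | x), and its observed outcome y (nothing else).
n, n_actions = 1000, 5
mu = np.full((n, n_actions), 1.0 / n_actions)   # known propensities
a = np.array([rng.choice(n_actions, p=p) for p in mu])
y = rng.normal(size=n)                          # factual outcomes only

def ipw_loss(f_pred, a, y, mu):
    """L_IPW = (1/N) sum_n (y^(n) - f(x^(n), a^(n)))^2 / mu(a^(n) | x^(n))."""
    rows = np.arange(len(a))
    return np.mean((y - f_pred[rows, a]) ** 2 / mu[rows, a])

f_pred = np.zeros((n, n_actions))               # a trivial predictor f = 0
print(ipw_loss(f_pred, a, y, mu))
```

Under a uniform logging policy the weights are constant (= |𝒜|), so the IPW loss is just a rescaled factual MSE; the weighting only matters when μ is biased.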

5.

Remark 2/2: Causal Inference for Decision-making
Redefine utility/reward (what we want to maximize) as the outcome.
◆ Decision-making is (or should be) maximization of expected utility (according to the Expected Utility Theory).
 ■ Von Neumann–Morgenstern utility theorem: iff the preference ≿ over distributions (actions) satisfies some axioms (completeness, transitivity, continuity, and independence), there exists a utility function u : 𝒴 → ℝ such that, for any two lotteries (distributions over 𝒴) P, Q,
   P ≿ Q ⇔ 𝔼_{y∼P}[u(y)] ≥ 𝔼_{y∼Q}[u(y)].
◆ Decision-making can thus be seen as maximization of 𝔼_y[u(y)].
 ■ Redefine y ← u(y) by modeling the preference of the decision-maker, and maximize 𝔼[y | x, a] =: f*(x, a).
(Diagram: covariate x and action a determine the outcome y via f*(x, a); actions are logged under the propensity μ(a | x); redefining utility as outcome maps the left diagram onto the right one.)
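As a toy illustration of the expected-utility view above (both lotteries and the utility function here are hypothetical, not from the talk): a concave, risk-averse u can prefer a safer lottery P over a riskier Q even when Q has the higher mean outcome, and redefining the outcome as u(y) turns preference maximization into plain expected-outcome maximization.

```python
import numpy as np

rng = np.random.default_rng(0)

# Two hypothetical lotteries (outcome distributions): P is safer,
# Q has a higher mean but far more variance.
p_outcomes = rng.normal(loc=1.0, scale=0.5, size=100_000)
q_outcomes = rng.normal(loc=1.2, scale=3.0, size=100_000)

def u(y):
    """A hypothetical risk-averse (concave) utility function."""
    return 1.0 - np.exp(-y)

# P >= Q in preference  <=>  E_{y~P}[u(y)] >= E_{y~Q}[u(y)]
print(np.mean(u(p_outcomes)), np.mean(u(q_outcomes)))
```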

6.
Problem Setting
Evaluate models via the outcome of a "plug-in" policy.
◆ Given training data D = {(x^(n), a^(n), y^(n))}_n
◆ Train a model f̂(x, a) ≃ 𝔼[y | x, a] via f̂ = arg min_f L(f; D)
◆ Optimize the decision policy π̂(a | x) with f̂:
 ◆ Deterministic policy: π̂(x) := arg max_{a∈𝒜} f̂(x, a)
 ◆ Or a stochastic one: π̂ := arg max_{π∈Π} 𝔼_{a∼π}[f̂(x, a)]
◆ E.g., the top-k policy space Π_k = {π ∣ π(a | x) ≤ 1/k ∀a, x}, with
   π_k^f(a | x) := 1/k if rank(f(x, a); {f(x, a′)}_{a′}) ≤ k, and 0 otherwise
◆ Evaluate the value V(π̂) := 𝔼_{π̂(a∣x)p(x)}[f*(x, a)]
◆ Or the regret vs. V(π*) := 𝔼_{π*(a∣x)p(x)}[f*(x, a)] = max_{π∈Π} V(π):
   Regret_Π := V(π*) − V(π̂)
(Diagram: "Train" produces f̂ from D; "Evaluate" scores the induced policy π̂.)
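The plug-in top-k policy and its value can be sketched as follows. This is a minimal numpy illustration with hypothetical predictions, not the talk's code; the true outcome model f* is assumed to be available as an array for evaluation.

```python
import numpy as np

def topk_policy(f_pred, k):
    """Plug-in top-k policy pi_k^f: probability 1/k on each of the k actions
    with the highest predicted outcome, zero elsewhere.
    f_pred: (n, |A|) array of predictions f_hat(x, a)."""
    pi = np.zeros_like(f_pred)
    top = np.argsort(-f_pred, axis=1)[:, :k]       # indices of top-k actions
    np.put_along_axis(pi, top, 1.0 / k, axis=1)
    return pi

def value(pi, f_star):
    """Plug-in value V(pi) = E_{pi(a|x)p(x)}[f*(x, a)] on a finite sample."""
    return np.mean(np.sum(pi * f_star, axis=1))

# Hypothetical predictions for 2 units and 3 actions
f_pred = np.array([[0.1, 0.9, 0.5],
                   [0.8, 0.2, 0.3]])
pi = topk_policy(f_pred, k=1)
print(pi)
```

With k = 1 this reduces to the deterministic arg-max policy; the regret is then `value(topk_policy(f_star, k), f_star) - value(topk_policy(f_pred, k), f_star)`.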

7.

Theory
Classification error of relatively good/bad actions matters.
Proposition: The regret is bounded with the geometric mean of the MSE and a classification error rate:
   Regret_{Π_k}(f) ≤ (|𝒜|/k) √(ER_k^u(f) ⋅ MSE^u(f)),
where
   ER_k^u(f) := 𝔼_x[(1/|𝒜|) Σ_{a∈𝒜} I((rank(y_a) ≤ k) ⊕ (rank(f(x, a)) ≤ k))]
    (classification of "whether the action is in the top-k"), and
   MSE^u(f) := 𝔼_x[(1/|𝒜|) Σ_{a∈𝒜} 𝔼_{y_a|x}[(y_a − f(x, a))²]]
    (a "uniform" MSE).
◆ Since ER_k^u(f) ≤ 2, the uniform MSE^u(f), which is often evaluated in causal inference even in the |𝒜| > 2 setting, by itself bounds the decision-making performance.
◆ The objective of causal inference is thus justified in terms of decision-making.
◆ This bound is tight (the regret can be almost equal to the r.h.s. with |𝒜|, k, ER_k^u(f), and MSE^u(f) fixed, under some conditions).
◆ 😢 The decision-making performance is hard to guarantee when |𝒜|/k is large (an unrealistically small MSE is required).
◆ 😃 We can further improve the decision-making performance by also minimizing ER_k^u(f).
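The two error terms and the regret can be checked numerically on synthetic data. The sketch below is a hypothetical illustration, not the paper's experiment: noiseless "true" outcomes stand in for 𝔼[y_a | x], all sizes are made up, and the comment states the bound form being checked.

```python
import numpy as np

def topk_mask(v, k):
    """Boolean mask of the k largest entries in each row."""
    m = np.zeros(v.shape, dtype=bool)
    np.put_along_axis(m, np.argsort(-v, axis=1)[:, :k], True, axis=1)
    return m

def er_k(f_pred, y_true, k):
    """ER_k^u: rate of XOR disagreement on "is this action in the top-k?"."""
    return np.mean(topk_mask(y_true, k) ^ topk_mask(f_pred, k))

def mse_u(f_pred, y_true):
    """Uniform MSE over all actions (y_true plays the role of E[y_a | x])."""
    return np.mean((y_true - f_pred) ** 2)

def regret_k(f_pred, y_true, k):
    """Regret of the plug-in top-k policy against the true top-k policy."""
    v_hat = np.take_along_axis(
        y_true, np.argsort(-f_pred, axis=1)[:, :k], axis=1).mean()
    v_star = np.take_along_axis(
        y_true, np.argsort(-y_true, axis=1)[:, :k], axis=1).mean()
    return v_star - v_hat

rng = np.random.default_rng(1)
y = rng.normal(size=(200, 8))            # hypothetical true outcomes, |A| = 8
f = y + 0.5 * rng.normal(size=y.shape)   # a noisy predictor of them
k = 2
lhs = regret_k(f, y, k)
# Checking the bound Regret <= (|A|/k) * sqrt(ER_k^u * MSE^u) on this sample
rhs = (y.shape[1] / k) * np.sqrt(er_k(f, y, k) * mse_u(f, y))
print(lhs, "<=", rhs)
```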

8.

Theory
Classification error of relatively good/bad actions matters (cont.). In the general policy space Π,
   Regret_Π(f̂) ≤ √(|𝒜| ⋅ MSE^u(f̂) ⋅ ER_Π^u(f̂)),
   where ER_Π^u(f̂) := 𝔼_x[Σ_{a∈𝒜} (π*(a ∣ x) − π̂(a ∣ x))²];
the rest of this slide restates the top-k proposition, the definitions of ER_k^u(f) and MSE^u(f), and the remarks from the previous slide.

9.

Method
Minimize the geometric mean of the MSE and the error rate, plus representation balancing (by the Wasserstein distance).
◆ The Regret Minimization Network (RMNet) minimizes
   L(f) = √(ẼR_g(f) ⋅ MSE(f)) + ℜ(f), with ℜ(f) = α ⋅ D_Wass({ϕ(x_n, a_n)}_n, {ϕ(x_n, a_n^u)}_n)
   (error rate × factual-outcome MSE under the geometric mean, plus the representation balancing regularizer; (x_n, a_n) are the factual pairs from the training data D, and a_n^u is a random action drawn from Unif(𝒜)).
◆ ẼR_g(f) = −(1/|N_tr|) Σ_n {s^(n) log v^(n) + (1 − s^(n)) log(1 − v^(n))},
   where s := I(y ≥ g(x)) and v := σ(f(x, a) − g(x)).
◆ g(x) ≃ 𝔼[y | x] is a baseline estimator trained beforehand on {(x^(n), y^(n))}.
◆ D_Wass(p_1, p_2) := sup_{h: 1-Lipschitz} ∫ h(z)(p_1(z) − p_2(z))dz.
◆ Representation balancing (D_Wass): the loss on the uniform action distribution is bounded by the loss on the data distribution plus the Wasserstein distance, L^u ≤ L + α ⋅ D_Wass (provided that ϕ has an inverse ϕ⁻¹).
(Architecture figure: the representations Φ = ϕ(x, a) and Φ^u = ϕ(x, a^u) share a head h producing ŷ_a; the MSE and the cross-entropy ẼR_g act on the factual branch, while D_Wass induces similarity between the distributions of Φ and Φ^u.)
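The surrogate error-rate term ẼR_g can be sketched as a per-sample cross-entropy. This is a hypothetical minimal version, not RMNet itself: the baseline g(x) is assumed to be available as an array, and the leading minus sign makes it a loss to minimize.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def surrogate_error_rate(f_pred, y, g_base):
    """Cross-entropy surrogate ER~_g: classify whether the factual outcome
    beats the personalized baseline g(x), with predicted probability
    v = sigma(f(x, a) - g(x)).  All arguments are per-sample arrays."""
    s = (y >= g_base).astype(float)      # s = I(y >= g(x))
    v = sigmoid(f_pred - g_base)         # v = sigma(f(x, a) - g(x))
    return -np.mean(s * np.log(v) + (1 - s) * np.log(1 - v))

# Hypothetical toy check: predictions and outcomes on the same side of the
# baseline give a small loss (~0.13); flipping the predictions inflates it.
f_pred = np.array([2.0, -2.0])
y      = np.array([1.0, -1.0])
g_base = np.array([0.0,  0.0])
print(surrogate_error_rate(f_pred, y, g_base))
```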

10.

Intuition
Consider the classification error relative to a personalized baseline, in addition to a regression error (the "uniform" MSE).
◆ Aim at minimizing the regret := the gap between the predicted-best action and the true best action.
◆ The MSE alone does not tell you whether f_A or f_B is better.
(Figure: two predictors ŷ_a = f_A(x, a) and ŷ_a = f_B(x, a) with the same MSE against the actual 𝔼[y_a | x]. f_A keeps every action on the correct side of the baseline ȳ = g(x) (error rate = 0) and picks nearly the best action, so its regret is small; f_B misclassifies 6 of 9 actions relative to ȳ (error rate = 6/9) and incurs a large regret, so f_A is better for decision-making.)
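The f_A/f_B intuition above can be reproduced with three actions and hypothetical numbers: two predictors with identical MSE, one of which preserves the best action while the other does not.

```python
import numpy as np

y   = np.array([1.0, 2.0, 3.0])   # true outcomes for three actions
f_a = np.array([1.5, 2.5, 3.5])   # uniform +0.5 error: ranking preserved
f_b = np.array([1.5, 2.5, 2.5])   # same |error| per action: best action lost

mse_a = np.mean((y - f_a) ** 2)             # 0.25
mse_b = np.mean((y - f_b) ** 2)             # 0.25 as well
regret_a = y.max() - y[np.argmax(f_a)]      # 0.0: picks the true best action
regret_b = y.max() - y[np.argmax(f_b)]      # 1.0: picks a suboptimal action
print(mse_a, mse_b, regret_a, regret_b)
```

Identical regression accuracy, different decisions: exactly why the method tracks a classification-style error in addition to the MSE.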

11.

Experiment
Semi-synthetic data made by sub-sampling a "complete" dataset.
◆ Source data: the SGEMM GPU kernel performance dataset.
 ◆ For all combinations of 14-dimensional GPU kernel parameters (N = 241k), four computation runs are recorded; the inverse of their average is the target y_a.
 ◆ The 14 dimensions are separated into a 3–6-dimensional action a and the other features x.
◆ Biased sub-sampling is applied only to the training data:
   p(a | x, y) ∝ exp(−|10(y − [x⊤, a⊤]⊤w)|)
 ◆ The outcome thus has a pseudo-correlation with a random 1-dimensional representation [x⊤, a⊤]⊤w (w is randomly generated) linking the action a and the other features x.
◆ Validation is based on the complete data.
(Table: illustration of the complete dataset Y_a over all (x, a) combinations, with the biased train split and the complete test split; numeric entries omitted.)
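The biased sub-sampling step can be sketched as follows. This is a hypothetical miniature, not the paper's preprocessing: the sizes, data, and w are made up, and normalizing the unnormalized p(a | x, y) with a softmax over actions is an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical miniature of the biased logging policy used for sub-sampling:
# p(a | x, y) propto exp(-|10 (y_a - [x, a]^T w)|) with a random w.
n, d_x = 6, 2
actions = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)
x = rng.normal(size=(n, d_x))
y = rng.normal(size=(n, len(actions)))    # "complete" outcomes, all actions
w = rng.normal(size=d_x + actions.shape[1])  # random 1-D representation

def sampling_probs(x_i, y_i):
    """Normalized sampling distribution over actions for one unit."""
    feats = np.hstack([np.tile(x_i, (len(actions), 1)), actions])
    logits = -np.abs(10.0 * (y_i - feats @ w))
    p = np.exp(logits - logits.max())     # numerically stabilized softmax
    return p / p.sum()

# Keep one biased factual action per unit for the training split
a_idx = np.array([rng.choice(len(actions), p=sampling_probs(x[i], y[i]))
                  for i in range(n)])
print(a_idx)
```

Because the sampling weight depends on y through the random representation, the factual training actions are spuriously correlated with the outcome, which is what the complete test split then exposes.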

12.

Result
The proposed method performs best; the classification error matters.
Table 1: Final decision-making performance, regression accuracy, and classification accuracy on the semi-synthetic data. Means and standard errors over 10 train/test splits are shown; the best result in each setting and metric is in bold, the second best underlined. (Columns: normalized plug-in policy value V̂(π^f_{k=1}), MSE^u, and ER^u_{k=1}, each for |𝒜| ∈ {8, 16, 32, 64}; rows: OLS, RF, kNN, BART, Multi-head DNN, Single-head DNN, CFRNet, RankNet, and RMNet (proposed); numeric entries omitted.)
◆ Although the proposed method does not necessarily minimize MSE^u, it is superior in ER^u_{k=1} and also in decision performance.
◆ The classification error ER^u_{k=1} is more consistent with the final performance than the (uniform) regression error MSE^u.
◆ Existing causal inference methods do not necessarily perform well here.
◆ The larger the action space |𝒜| is, the better the single-head architecture would be.
Table 2: Ablation of the loss components (IPM, MSE, ER, Bilinear) in terms of the normalized plug-in policy value V̂(π^f_{k=1}) on synthetic and semi-synthetic data (|𝒜| = 32, 64); numeric entries omitted.

13.

Summary
Redefined causal inference from a decision-making perspective and discovered the importance of a certain kind of classification accuracy.
◆ Analyzed the goodness of prediction-based decision-making and investigated learning methods that contribute to decision-making.
◆ Theoretically confirmed that the uniform MSE, a common objective in causal inference, contributes to decision-making performance.
◆ On the other hand, found that a certain kind of classification accuracy is also important.
◆ Proposed a learning method whose objective function includes the geometric mean of the regression and classification accuracies.
◆ Confirmed its superiority on semi-synthetic data.
 ◆ Empirically, classification accuracy is more important than regression accuracy.

14.

Reference
◆ Akira Tanimoto, Tomoya Sakai, Takashi Takenouchi, and Hisashi Kashima. Regret minimization for causal inference on large treatment space. In AISTATS, 2021.
◆ Uri Shalit, Fredrik D. Johansson, and David Sontag. Estimating individual treatment effect: generalization bounds and algorithms. In Proceedings of the 34th International Conference on Machine Learning, Volume 70, pp. 3076–3085. JMLR.org, 2017.
◆ Jennifer L. Hill. Bayesian nonparametric modeling for causal inference. Journal of Computational and Graphical Statistics, Vol. 20, No. 1, pp. 217–240, 2011.
◆ Hao Zou, et al. Counterfactual prediction for outcome-oriented treatments. In International Conference on Machine Learning. PMLR, 2022.

NEC Group Internal Use Only