Regret Minimization for Causal Inference on Large Treatment Space (AISTATS'21)
2023/4/12
Akira Tanimoto (1, 2, 3), Tomoya Sakai (1, 3), Takashi Takenouchi (4, 3), Hisashi Kashima (2) @akira_tanimoto
(1. NEC, 2. Kyoto Univ., 3. Riken AIP, 4. Future Univ. Hakodate)
Contents
◆ Background: Association does not imply causation
◆ Remark 1/2: Difference between supervised ML and causal inference
◆ Remark 2/2: Causal inference for decision-making
◆ Problem Setting
◆ Theory
◆ Method
◆ Intuition
◆ Experiment
◆ Result
◆ Summary
Background: Association Does Not Imply Causation
Decision-making based on "what-if" prediction needs to account for causality
◆ We aim to maximize the outcome by choosing the best action under an unknown/uncertain mechanism.
  ■ Learn an outcome predictor $\hat f(x, a)$, choose the best-predicted action $\hat a^*$, and then observe the actual outcome $y \simeq f^*(x, \hat a^*)$.
◆ Conventional supervised learning is not sufficient because
  ■ "Those who were prescribed have a poor prognosis" ⇏ "Prescribing won't work."
  ■ ∵ Biased logging policy $\mu(a \mid x)$ (aka propensity): e.g., a 70-year-old patient with hyperglycemia is prescribed while a 30-year-old is not, so $\mu(a \mid x) \neq$ Uniform.
  ■ → Need causal inference (?)
[Figure: covariate $x$ and action $a$ enter the mechanism $f(\cdot, a)$, which produces the outcome $y$.]
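Remark 1/2: Difference between Supervised ML and Causal Inference
Only partial (bandit) feedback is available in causal inference
◆ Supervised ML: for a feature/image $x$, the prediction $\hat f(x)$ is compared with the label $y$ (e.g., "Dog" among "Dog", "Cat", "Monkey", "Pig") through a loss $\ell$, and the loss sums over all classes:
$$L_{\mathrm{CE}} = -\frac{1}{N}\sum_{n}\sum_{i} y_i^{(n)} \log \hat f_i^{(n)}$$
◆ Causal inference: for a covariate $x$, each action $a$ has its own potential outcome $y_a$ produced by $f(\cdot, a)$, but only the outcome of the action actually taken is observed. Inverse probability weighting (IPW) de-biases the propensity:
$$L_{\mathrm{IPW}} = \frac{1}{N}\sum_{n} \frac{1}{\mu(a^{(n)} \mid x^{(n)})}\left(y^{(n)} - \hat f(x^{(n)}, a^{(n)})\right)^2$$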
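To make the contrast concrete, here is a minimal pure-Python sketch of the two losses above. Assumptions: inputs are plain lists; `probs` holds predicted class probabilities, `f_hat` holds the model's prediction at each logged action $\hat f(x^{(n)}, a^{(n)})$, and `propensity` holds the known logging probabilities $\mu(a^{(n)} \mid x^{(n)})$. The function names are illustrative.

```python
import math


def cross_entropy_loss(probs, labels):
    """Supervised ML (full feedback): labels[n] is one-hot over all classes,
    probs[n] the predicted class probabilities; the loss sums over classes."""
    total = 0.0
    for p, y in zip(probs, labels):
        # terms with y_i = 0 contribute nothing, so skip them (also avoids log(0))
        total -= sum(y_i * math.log(p_i) for y_i, p_i in zip(y, p) if y_i)
    return total / len(probs)


def ipw_squared_loss(y, f_hat, propensity):
    """Causal inference (bandit feedback): only the logged action's outcome
    y[n] is observed; its squared error is weighted by 1 / mu(a^(n) | x^(n))."""
    total = 0.0
    for y_n, f_n, mu_n in zip(y, f_hat, propensity):
        total += (y_n - f_n) ** 2 / mu_n
    return total / len(y)


# Usage: two logged samples whose actions had propensity 0.5 and 0.25.
print(ipw_squared_loss([1.0, 0.0], [0.5, 0.5], [0.5, 0.25]))  # 0.75
```

The weighting is what removes the propensity bias: for a fixed $x$, $\mathbb{E}_{a \sim \mu}[\ell(a)/\mu(a \mid x)] = \sum_{a} \ell(a)$, so in expectation every action counts equally, as if actions had been logged uniformly (up to the constant $|\mathcal{A}|$).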
Remark 2/2: Causal Inference for Decision-making
Redefine utility / reward (what we want to maximize) as the outcome
◆ Decision-making is (or should be) maximization of expected utility (according to Expected Utility Theory)
  ■ Von Neumann–Morgenstern utility theorem
  ■ Iff the preference ≿ over distributions (actions) satisfies certain axioms (completeness, transitivity, continuity, and independence), there exists a utility function $u: \mathcal{Y} \to \mathbb{R}$ such that for any two lotteries (distributions over $\mathcal{Y}$) $P, Q$:
$$P \succsim Q \iff \mathbb{E}_{y \sim P}[u(y)] \geq \mathbb{E}_{y \sim Q}[u(y)]$$
◆ Decision-making can therefore be seen as maximization of $\mathbb{E}_y[u(y)]$
  ■ By redefining $y \leftarrow u(y)$, we model and maximize $\mathbb{E}[y \mid x, a] =: f^*(x, a)$
[Figure: left, covariate $x$, propensity $\mu(a \mid x)$, action $a$, and outcome $y$, with the utility $u$ encoding the decision-maker's preference; right, the utility redefined as the outcome, $f^*(x, a)$.]
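A toy sketch of why redefining $y \leftarrow u(y)$ suffices: once outcomes are replaced by utilities, maximizing the expected outcome is the same as maximizing expected utility. The two lotteries and the concave utility $u(y) = \sqrt{y}$ are hypothetical, chosen only so that the risk-neutral and utility-based choices differ.

```python
import math

# Hypothetical actions, each inducing a lottery over outcomes y: [(probability, y), ...].
lotteries = {
    "safe": [(1.0, 50.0)],
    "risky": [(0.5, 0.0), (0.5, 110.0)],
}


def expected(lottery, u=lambda y: y):
    """E_{y ~ lottery}[u(y)]; the default u is the identity."""
    return sum(p * u(y) for p, y in lottery)


def best_action(lotteries, u=lambda y: y):
    """Action whose lottery has the highest expected utility."""
    return max(lotteries, key=lambda a: expected(lotteries[a], u))


u = math.sqrt  # hypothetical risk-averse (concave) utility

# Redefine the outcome as the utility, y <- u(y).
redefined = {a: [(p, u(y)) for p, y in lot] for a, lot in lotteries.items()}

print(best_action(lotteries))     # risk-neutral: expected y is 55 vs. 50, prints risky
print(best_action(lotteries, u))  # expected utility: about 5.24 vs. 7.07, prints safe
print(best_action(redefined))     # plain expectation of the redefined outcome, prints safe
```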
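Problem Setting
Evaluate models via the outcome of a "plug-in" policy
◆ Given training data $D = \{(x^{(n)}, a^{(n)}, y^{(n)})\}_{n}$
◆ Train a model $\hat f = \arg\min_f L(f; D)$ with $\hat f(x, a) \simeq \mathbb{E}[y \mid x, a]$
◆ Optimize the decision policy $\hat\pi(a \mid x)$ with $\hat f$:
$$\hat\pi := \arg\max_{\pi \in \Pi} \mathbb{E}_{p(x)\pi(a \mid x)}[\hat f(x, a)]$$
  ■ Deterministic policy: $\pi_{\hat f}(a \mid x) = 1$ if $a = \arg\max_{a'} \hat f(x, a')$, and $0$ otherwise
  ■ E.g., the top-$k$ policy space $\Pi_k = \{\pi \mid \pi(a \mid x) \leq 1/k \;\forall a, x\}$, in which the plug-in optimum is
$$\pi_k(a \mid x) := \begin{cases} 1/k & \text{if } \mathrm{rank}\big(\hat f(x, a); \{\hat f(x, a')\}_{a'}\big) \leq k \\ 0 & \text{otherwise} \end{cases}$$
◆ Evaluate the value $V(\hat\pi) := \mathbb{E}_{p(x)\hat\pi(a \mid x)}[f^*(x, a)]$
◆ Or the regret vs. $V(\pi^*) := \mathbb{E}_{p(x)\pi^*(a \mid x)}[f^*(x, a)] = \max_{\pi \in \Pi} V(\pi)$:
$$\mathrm{Regret}_\Pi := V(\pi^*) - V(\hat\pi)$$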
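The plug-in evaluation above can be sketched in a few lines of pure Python, assuming a finite action set, an empirical (uniform) $p(x)$ over a list of contexts, and access to the true expected outcomes $f^*(x, a)$ (as in a semi-synthetic benchmark). Each context is a list of scores indexed by action; the helper names are illustrative.

```python
def top_k_policy(scores, k):
    """pi_k(a | x): 1/k on the k highest-scoring actions, 0 elsewhere.
    sorted() is stable, so ties are broken by the lower action index."""
    order = sorted(range(len(scores)), key=lambda a: -scores[a])
    top = set(order[:k])
    return [1.0 / k if a in top else 0.0 for a in range(len(scores))]


def policy_value(policies, f_star):
    """V(pi) = E_{p(x) pi(a|x)}[f*(x, a)] with p(x) uniform over the given contexts."""
    per_context = [sum(p * y for p, y in zip(pi, ys)) for pi, ys in zip(policies, f_star)]
    return sum(per_context) / len(per_context)


def regret_top_k(f_hat, f_star, k):
    """Regret_{Pi_k} = V(pi*) - V(pi_hat) for the top-k plug-in policy."""
    pi_hat = [top_k_policy(scores, k) for scores in f_hat]  # plug-in policy from the model
    pi_star = [top_k_policy(ys, k) for ys in f_star]        # optimal policy within Pi_k
    return policy_value(pi_star, f_star) - policy_value(pi_hat, f_star)


# Usage: two contexts, three actions.
f_star = [[1.0, 3.0, 2.0], [0.0, 0.5, 4.0]]  # true expected outcomes
f_hat = [[2.0, 1.0, 0.0], [0.1, 0.2, 3.0]]   # model predictions
print(regret_top_k(f_hat, f_star, k=1))  # 1.0: the model picks a=0 instead of a=1 in context 1
```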
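Theory
Classification error of relatively good/bad actions matters
◆ Proposition: The regret is bounded by the geometric mean of the MSE and a classification error rate:
$$\mathrm{Regret}_{\Pi_k}(\hat f) \leq \frac{|\mathcal{A}|}{k}\sqrt{\mathrm{ER}^u_k(\hat f) \cdot \mathrm{MSE}^u(\hat f)}$$
  where
$$\mathrm{ER}^u_k(\hat f) := \mathbb{E}_x\left[\frac{1}{|\mathcal{A}|}\sum_{a \in \mathcal{A}} \mathbb{I}\big((\mathrm{rank}(y_a) \leq k) \oplus (\mathrm{rank}(\hat f(x, a)) \leq k)\big)\right]$$
  is the error of classifying "is the action in the top-$k$?", and
$$\mathrm{MSE}^u(\hat f) := \mathbb{E}_x\left[\frac{1}{|\mathcal{A}|}\sum_{a \in \mathcal{A}} \mathbb{E}_{y_a \mid x}\big[(y_a - \hat f(x, a))^2\big]\right]$$
  ■ In a general policy space,
$$\mathrm{Regret}_\Pi(\hat f) \leq \sqrt{|\mathcal{A}| \cdot \mathrm{MSE}^u(\hat f) \cdot \mathrm{ER}_\Pi(\hat f)}, \quad \mathrm{ER}_\Pi(\hat f) := \mathbb{E}_x\left[\sum_{a \in \mathcal{A}} \big(\pi^*(a \mid x) - \hat\pi(a \mid x)\big)^2\right]$$
◆ Since at most $k$ actions can be wrongly included and $k$ wrongly excluded, $\mathrm{ER}^u_k(\hat f) \leq 2k/|\mathcal{A}|$, so $\mathrm{Regret}_{\Pi_k}(\hat f) \leq \sqrt{2|\mathcal{A}|/k \cdot \mathrm{MSE}^u(\hat f)}$: a "uniform" MSE, $\mathrm{MSE}^u(\hat f)$, which is often evaluated in causal inference in the $|\mathcal{A}| > 2$ setting, actually bounds the decision-making performance.
◆ The objective of causal inference is thus justified in terms of decision-making.
◆ This bound is tight (the regret can be almost equal to the r.h.s. for some fixed $|\mathcal{A}|$, $k$, $\mathrm{ER}^u_k(\hat f)$, $\mathrm{MSE}^u(\hat f)$).
◆ 😢 Decision-making performance is hard to guarantee when $|\mathcal{A}|/k$ is large (an unrealistically small MSE would be required).
◆ 😃 We can further improve decision-making performance by also minimizing $\mathrm{ER}^u_k(\hat f)$.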
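The proposition can be checked numerically. The sketch below computes $\mathrm{ER}^u_k$, $\mathrm{MSE}^u$, and the top-$k$ regret on synthetic data and compares the regret with $\frac{|\mathcal{A}|}{k}\sqrt{\mathrm{ER}^u_k \cdot \mathrm{MSE}^u}$. Assumptions: outcomes are noise-free, so the true expected outcome $f^*(x, a)$ stands in for $y_a$; the Gaussian outcomes and prediction noise are arbitrary illustrative choices, not the paper's setting.

```python
import math
import random


def top_k_set(scores, k):
    """Indices of the k highest scores."""
    return set(sorted(range(len(scores)), key=lambda a: -scores[a])[:k])


def regret_er_mse(f_star, f_hat, k):
    """Empirical Regret_{Pi_k}, ER^u_k and MSE^u averaged over a list of contexts."""
    n, m = len(f_star), len(f_star[0])
    regret = er = mse = 0.0
    for ys, fs in zip(f_star, f_hat):
        true_top, pred_top = top_k_set(ys, k), top_k_set(fs, k)
        # value gap between the oracle and plug-in top-k policies (weight 1/k each)
        regret += (sum(ys[a] for a in sorted(true_top)) - sum(ys[a] for a in sorted(pred_top))) / k
        # "is the action in the top-k?" misclassifications = symmetric difference (XOR)
        er += len(true_top ^ pred_top) / m
        # "uniform" MSE: squared error averaged over all actions, not just logged ones
        mse += sum((y - f) ** 2 for y, f in zip(ys, fs)) / m
    return regret / n, er / n, mse / n


random.seed(0)
num_actions, k = 20, 3
f_star = [[random.gauss(0.0, 1.0) for _ in range(num_actions)] for _ in range(500)]
f_hat = [[y + random.gauss(0.0, 0.5) for y in ys] for ys in f_star]

regret, er, mse = regret_er_mse(f_star, f_hat, k)
bound = (num_actions / k) * math.sqrt(er * mse)
print(f"regret={regret:.4f}  bound={bound:.4f}  ER={er:.4f} (<= 2k/|A| = {2 * k / num_actions})")
```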
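Method
Minimize the geometric mean of the MSE and the error rate + representation balancing (by the Wasserstein distance)
◆ Regret Minimization Network (RMNet) minimizes, on the training data $D$,
$$L(\hat f) = \sqrt{\widetilde{\mathrm{ER}}_g(\hat f) \cdot \mathrm{MSE}(\hat f)} + \mathfrak{R}(\hat f), \quad \mathfrak{R}(\hat f) = \alpha \cdot D_{\mathrm{Wass}}\big(\{\phi(x^{(n)}, a^{(n)})\}_n, \{\phi(x^{(n)}, a_u^{(n)})\}_n\big)$$
  where the MSE is computed on the factual outcome, $a^{(n)}$ is the factual action, $a_u^{(n)}$ is a random action drawn from $\mathrm{Unif}(\mathcal{A})$, and $\mathfrak{R}$ is the representation balancing regularizer.
◆ Error rate (cross-entropy surrogate):
$$\widetilde{\mathrm{ER}}_g(\hat f) = -\frac{1}{|N_{\mathrm{tr}}|}\sum_n \left\{s^{(n)} \log v^{(n)} + (1 - s^{(n)})\log(1 - v^{(n)})\right\}$$
  ■ $s^{(n)} := \mathbb{I}(y^{(n)} \geq g(x^{(n)}))$, $v^{(n)} := \sigma(\hat f(x^{(n)}, a^{(n)}) - g(x^{(n)}))$
  ■ $g(x) \simeq \mathbb{E}[y \mid x]$ is a baseline estimator trained beforehand on $\{(x^{(n)}, y^{(n)})\}$
◆ Representation balancing ($D_{\mathrm{Wass}}$): the loss on the uniform action distribution is bounded by the loss on the data distribution plus the Wasserstein distance,
$$L^u \leq L + \alpha \cdot D_{\mathrm{Wass}} \quad (\text{provided that } \phi \text{ has an inverse } \phi^{-1}),$$
$$D_{\mathrm{Wass}}(p_1, p_2) := \sup_{h:\,1\text{-Lipschitz}} \int_{\mathcal{Z}} h(z)\,(p_1(z) - p_2(z))\,dz$$
[Figure: RMNet architecture. $(x, a)$ and $(x, a_u)$ are mapped by the representation $\phi$ to $\Phi$ and $\Phi^u$; $D_{\mathrm{Wass}}$ induces similarity between the distributions of $\Phi$ and $\Phi^u$; the head $h$ outputs $\hat y_a$, trained with the MSE against $y$ and the cross-entropy $\widetilde{\mathrm{ER}}_g$ against the baseline $g$.]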
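A minimal pure-Python sketch of evaluating the RMNet objective on one batch. It only computes the loss value; in practice $\hat f$, $\phi$, and $h$ would be neural networks trained with an autodiff framework. Assumptions: `f_fact` are predictions at the factual actions, `g` the pre-trained baseline predictions $g(x^{(n)})$, and `phi_fact`, `phi_unif` the representations $\phi(x^{(n)}, a^{(n)})$ and $\phi(x^{(n)}, a_u^{(n)})$ as equal-length lists of vectors. The slide does not say how $D_{\mathrm{Wass}}$ is estimated; as a stand-in we swap in the sliced 1-Wasserstein distance (exact 1-D distances averaged over random projections), a simplification rather than the method's own estimator.

```python
import math
import random


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def surrogate_error_rate(y, f_fact, g, eps=1e-12):
    """ER~_g: cross entropy between s = 1[y >= g(x)] and v = sigmoid(f(x, a) - g(x))."""
    total = 0.0
    for y_n, f_n, g_n in zip(y, f_fact, g):
        s = 1.0 if y_n >= g_n else 0.0
        v = sigmoid(f_n - g_n)
        total += s * math.log(v + eps) + (1.0 - s) * math.log(1.0 - v + eps)
    return -total / len(y)


def factual_mse(y, f_fact):
    """MSE on the observed (factual) outcomes."""
    return sum((y_n - f_n) ** 2 for y_n, f_n in zip(y, f_fact)) / len(y)


def sliced_wasserstein_1(phi_p, phi_q, n_proj=50, seed=0):
    """Stand-in for D_Wass (assumption): sliced 1-Wasserstein distance.
    For equal-size 1-D samples, W1 = mean |sorted(p) - sorted(q)|."""
    rng = random.Random(seed)
    dim = len(phi_p[0])
    total = 0.0
    for _ in range(n_proj):
        d = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # random projection direction
        norm = math.sqrt(sum(c * c for c in d))
        proj_p = sorted(sum(c * z for c, z in zip(d, row)) / norm for row in phi_p)
        proj_q = sorted(sum(c * z for c, z in zip(d, row)) / norm for row in phi_q)
        total += sum(abs(p - q) for p, q in zip(proj_p, proj_q)) / len(proj_p)
    return total / n_proj


def rmnet_objective(y, f_fact, g, phi_fact, phi_unif, alpha=1.0):
    """L(f) = sqrt(ER~_g(f) * MSE(f)) + alpha * D_Wass(Phi, Phi^u)."""
    geo_mean = math.sqrt(surrogate_error_rate(y, f_fact, g) * factual_mse(y, f_fact))
    return geo_mean + alpha * sliced_wasserstein_1(phi_fact, phi_unif)


# Usage with a tiny hypothetical batch of three units.
y = [3.0, 1.0, 4.0]
f_fact = [2.5, 1.5, 3.0]
g = [2.0, 2.0, 2.0]
phi_fact = [[0.1, 0.2], [0.4, 0.3], [0.9, 0.8]]
phi_unif = [[0.2, 0.1], [0.5, 0.5], [0.7, 0.9]]
print(rmnet_objective(y, f_fact, g, phi_fact, phi_unif, alpha=0.1))
```

The geometric mean mirrors the $\sqrt{\mathrm{ER} \cdot \mathrm{MSE}}$ form of the proposition, and rescaling either factor only rescales the product, so no weight trading off the two terms has to be tuned.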
Intuition
Consider the classification error relative to a personalized baseline in addition to a regression error (a "uniform" MSE)
◆ Aim to minimize the regret, i.e., the gap between the outcomes of the predicted-best action and the true best action.
◆ The MSE alone does not tell whether $f_A$ or $f_B$ is better.
[Figure: predicted $\hat y_a = f_A(x, a)$ (left) and $\hat y_a = f_B(x, a)$ (right) plotted against the actual $\mathbb{E}[y_a \mid x]$, with the baseline $\bar y = g(x)$ as the classification threshold and the regret shown as the gap $y_{a^*} - y_{\hat a^*}$. Both models have the same MSE, but $f_A$ has error rate 0 while $f_B$ has error rate 6/9; $f_A$ is the better one.]
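The same point in a toy example with hypothetical numbers (not the values in the figure): for a single context with five actions, two predictors have identical error magnitudes, hence the same MSE, but differ in how often they put an action on the wrong side of the baseline $g(x)$, and only that difference shows up in the regret.

```python
def summarize(y, f, g):
    """MSE, error rate of "is the outcome above the baseline g(x)?", and top-1 regret
    for one context; y[a] are true outcomes, f[a] predictions, g the baseline."""
    m = len(y)
    mse = sum((y_a - f_a) ** 2 for y_a, f_a in zip(y, f)) / m
    error_rate = sum((y_a >= g) != (f_a >= g) for y_a, f_a in zip(y, f)) / m
    a_hat = max(range(m), key=lambda a: f[a])  # predicted-best action
    regret = max(y) - y[a_hat]
    return mse, error_rate, regret


y = [1.0, 3.0, 5.0, 8.0, 9.0]  # hypothetical true outcomes y_a
g = 6.0                        # hypothetical personalized baseline g(x)
f_A = [y_a + e for y_a, e in zip(y, [-2, -2, -2, +2, +2])]  # errors push away from g
f_B = [y_a + e for y_a, e in zip(y, [+2, +2, +2, +2, -2])]  # same |errors|, other signs

print(summarize(y, f_A, g))  # (4.0, 0.0, 0.0)
print(summarize(y, f_B, g))  # (4.0, 0.2, 1.0)
```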
Experiment
Semi-synthetic data made by sub-sampling a "complete" dataset
◆ Source data: SGEMM GPU performance dataset
  ■ For all combinations of 14-dimensional GPU kernel parameters (N = 241k), four computation runs are recorded, and the inverse of their average is the target $y_a$.
  ■ The 14 dimensions are split into a 3–6-dimensional action $a$ and the other features $x$.
◆ The outcome has a pseudo-correlation with a random 1-dimensional representation $w^\top [x^\top, a^\top]^\top$ ($w$ randomly generated).
◆ Biased sub-sampling only for the training data:
$$p(a \mid x, y) \propto \exp\big(-10\,|y - w^\top [x^\top, a^\top]^\top|\big)$$
◆ Validation is based on the complete data.
[Figure: toy illustration of a complete dataset (outcomes for every action $a = (a_0, a_1, a_2)$ and covariate $x$) and its Train/Test split, with "–" marking outcomes not included.]
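A sketch of the biased training-set sub-sampling, assuming a complete table `Y[i][j]` of outcomes for every unit $i$ and every action in a finite list `actions`, and that exactly one logged action per unit is drawn from $p(a \mid x, y)$ normalized over that list. The shapes, the random $w$, and the function name are illustrative assumptions; the factor 10 comes from the slide. Actions whose outcome lies close to $w^\top [x^\top, a^\top]^\top$ are favored, so in the training data the outcome becomes spuriously correlated with that projection.

```python
import itertools
import math
import random


def biased_subsample(X, actions, Y, w, scale=10.0, seed=0):
    """Draw one logged action per unit with p(a | x, y) ∝ exp(-scale * |y_a - w^T [x; a]|)."""
    rng = random.Random(seed)
    logged = []
    for x, ys in zip(X, Y):
        logits = []
        for a, y_a in zip(actions, ys):
            proj = sum(w_j * z_j for w_j, z_j in zip(w, list(x) + list(a)))  # w^T [x; a]
            logits.append(-scale * abs(y_a - proj))
        top = max(logits)  # subtract the max before exp() for numerical stability
        weights = [math.exp(logit - top) for logit in logits]
        j = rng.choices(range(len(actions)), weights=weights, k=1)[0]
        logged.append((x, actions[j], ys[j]))
    return logged


# Usage with a small random "complete" dataset: 2-dim x, binary 3-dim a.
rng = random.Random(1)
actions = list(itertools.product([0, 1], repeat=3))
X = [[rng.random(), rng.random()] for _ in range(5)]
w = [rng.gauss(0.0, 1.0) for _ in range(2 + 3)]
Y = [[rng.random() for _ in actions] for _ in X]
train = biased_subsample(X, actions, Y, w)
```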
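Result
The proposed method performs best; the classification error matters
Table 1: Final decision-making performance, regression accuracy, and classification accuracy in the semi-synthetic data experiments. Means and standard errors over 10 trials of the train/test split are shown. For each setting and metric, the best result is in bold and the second best is underlined.
[Table 1 body: (normalized) plug-in policy value $V(\pi^{k=1}_{\hat f})$ per method (e.g., OLS) and action-space size $|\mathcal{A}|$, reported as mean ± standard error]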
Summary
Redefined causal inference from a decision-making perspective and discovered the importance of a certain kind of classification accuracy.
◆ Analyzed the quality of prediction-based decision-making and investigated learning methods that contribute to decision-making.
◆ Theoretically confirmed that the uniform MSE, a common goal in causal inference, contributes to decision-making performance.
◆ On the other hand, found that a certain kind of classification accuracy is also important.
◆ Proposed a learning method whose objective includes the geometric mean of the regression and classification errors.
◆ Confirmed its superiority on semi-synthetic data.
◆ Empirically, classification accuracy mattered more than regression accuracy.
© NEC Corporation 2021
