【DL輪読会】FedTabDiff: Federated Learning of Diffusion Probabilistic Models for Synthetic Mixed-Type Tabular Data Generation

2.6K Views

March 29, 24

スライド概要

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

関連スライド

各ページのテキスト
1.

DEEP LEARNING JP [DL Papers] FedTabDiff: Federated Learning of Diffusion Probabilistic Models for Synthetic Mixed-Type Tabular Data Generation Takeyuki Nakae http://deeplearning.jp/

2.

書誌情報 • 概要 テーブルデータを作成するDDPMと連合学習を組み合わせたもの • 実装 https://github.com/sattarov/FedTabDiff 2

3.

研究の背景 政府や地方自治体が様々なデータ(表データなど)を公開してデータ活用を進めている。 リスク: • データ漏洩 • データからの個人情報特定 特に現実世界の金融や保険データは、プライバシー・機密情報を扱うため、リスクが高い 情報漏洩のイメージ図(stable-diffusion) 個人情報特定のイメージ図(stable-diffusion) 3

4.

研究の背景 情報漏洩を軽減する対策: 表データの生成 課題 1. 多くの特徴量が存在する 2. 特徴量間の相関や依存関係が存在する 3. 特徴量の分布が不均衡である ↑最近の生成モデルであれば、膨大なデータを学習することで性能の良い 生成データを作成できる しかし現実では、 現実世界の金融や保険データではデータを共有して、一つに集約する などができない問題がある。 Generator Generator 4

5.

研究の背景 最近の生成モデルであれば、膨大なデータを学習することで性能の良い生成データを作成できるが、 現実世界の金融や保険データではデータを共有できない問題がある。 • 上↑の解決策、 学習データを様々なデバイスに分散化した上で、中央サーバーの下でAIモ デルを共同学習するFederated Learning(FL)が候補としてあげられる。 • Federated Learning(連合学習)のメリット 直接的にデータを送受信しない。 本研究では、DDPMとFLを融合する。 ↓すると… データプライバシーの向上+表データを生成できる新しい連合学習フレー ムワークができる。(提案する) 共有 共有 5

6.

事前知識: DDPM • 完全なノイズからノイズを徐々に除去することでデータを作成する方法 ノイズを徐々に除去する「ノイズ除去モデル𝑝𝜃 (𝑥𝑡 |𝑥𝑡−1 )」を利用することで、本物の画像𝑥0 に近 いデータを生成することができる。 ノイズがかかった画像からノイズ付与前の画像がどのようなものだったかを予測することで行う ※ 有名な例: Stable-diffusion ※ただしノイズから直接元画像を予測できないので、予測しやすいように徐々にノイズを付与する。このた めどれくらいノイズを付与したかのノイズレベルが存在する。 6

7.

事前知識: 連合学習 • 一般的な学習 AIモデルは通常は一つのサーバーなどでモデルを学習 • 連合学習(Federated Learning) 複数のサーバーや個人デバイス(これらをクライアントと呼ぶ)ある個別の データを、それぞれのクライアントが学習する。 ↓ 中央サーバーがモデルを管理し、最後に一つのモデルとして集約 メリット • データの送受信を行う必要がないため、データプライバシーを保護し ながら、データの活用が可能になる。 デメリット • 様々なタスクに応用できるが、これらのモデルを適切に訓練し、適用 するためには、専門的な知識と技術が必要。 共有 共有 連合学習のイメージ 7

8.

手法: ddpm 拡散モデルは以下のステップに分けて説明される。 • データにノイズを付与するForward Process • 付与したノイズを除去するReverse Process • Forward Processの数式は以下の(1)。 以下の数式で𝑥0 は初期データ(元のデータ)、𝛽𝑡 はプロセスtにおけるノイズレベル 初期データから完全なノイズになるまでの数式→ ノイズレベル𝛽𝑡 の数式→ 8

9.

手法: ddpm 拡散モデルは以下のステップに分けて説明される。 • データにノイズを付与するForward Process • 付与したノイズを除去するReverse Process • Reverse Processの数式は以下の様に表される。 このため推定にはパラメータ𝜃のモデルが訓練される。 除去するノイズの平均 値の推定方法→ そしてノイズ成分であるの𝛼𝑡 , 𝛽𝑡 , 𝜖𝜃 (𝑥𝑡 , 𝑡)の推定を行う→ そしてノイズ除去を行うパラメータ𝜃のモデルの最適化は最終的に以下の様になる 9

10.

手法: 連合学習(Federated Averaging) 概要 • 連合学習は様々なクライアント(Client 𝝎𝒊 )上にあるデータを、クライアント上のモデル(Local Model)が学習し、それを中央のサーバー(Server)に集約する。 𝝃 • 中央のサーバーでモデルを更新しクライアントに再度共有(𝜽𝒊 )する 10

11.

手法: 連合学習(Federated Averaging) 数式 𝜉 • 中央サーバーでモデルを更新、クライアントに再度共有(𝜃𝑖 )する モデルのパラメータを中央のサーバーに集約する時は最終的に以下の数式(4)となる 上記の数式は分散化されたモデル更新の加重平均を計算している。 𝐶:クライアントの数(それぞれを𝜔𝑖 ) 𝒟 = 𝐷𝑖 𝐶𝑖=1 :サブデータセット(対応するクライアント𝑖とデータ𝑖がデータを共有できる)※ 𝑟: 最適化の回数(通信ラウンド)。𝑟 = 1, . . 𝑅 最適化する度にクライアントへモデルを共有する。 各ラウンドにおいて、クライアント𝜔𝑖,𝑟 ⊆ 𝜔𝑖 𝐶𝑖=1 のサブセットのように一つ選択される。 ※サブデータセットの𝐷1 と𝐷2 は分布が異なる可能性がある 11

12.

実験の設定: データセット 使用したデータセット • Philadelphia City Payments Data(フィラデルフィア市支払いデータ) 58の異なる市の部署から作成された合計238894件のデータセット non-iid split(クライアントごとの分割基準): doc ref no prefix definition (文書参照番号接頭辞定義) • Diabetes Hospital Data(糖尿病病院データ) 米国の130の病院によって収集された糖尿病の臨床治療記録 1999年から2008年: 合計101767件 non-iid split(クライアントごとの分割基準): age(患者の年齢グループ) なおデータの変換は、数値はsklearnの変換を、カテゴリは正弦波位置埋め込みが利用される。(なんで正弦波 位置埋め込み?) 12

13.

評価指標の説明 • fidelity(忠実度) この評価指標は、合成データが実データをどれくらい再現できているかを評価する。 評価は行方向と列方向二つの観点で評価される。 列の評価は、合成データと実データセットの対応する列間の類似度が評価される。 数式: 𝑥 𝑑 は実データの列d 𝑠 𝑑 は合成データの列d • 数値データは、コルモゴロフ・スミルノフ統計量(KSS)で評価 →データの分布の一致度を評価 • カテゴリデータは全変動距離(TVD)で差異を定量化 →カテゴリデータの出現率の一致度を評価? 13

14.

評価指標の説明 • fidelity(忠実度) この評価指標は、合成データが実データをどれくらい再現できているかを評価する。 評価は行方向と列方向二つの観点で評価される。 行の評価は、一つの表に含まれる列ペア間の相関に注目。 数式: 𝑥 𝑎,𝑏 は実データの列d 𝑠 𝑎,𝑏 は合成データの列d 数値特徴: ペア間のピアソン相関(PC)で評価 カテゴリ特徴: 属性aとbのカテゴリー・ペアにわたって計算される。 行の忠実度は、全ての列のペアで評価され、平均して評価される。 この計算は合成データの列aと列b、実データの列aと列bの相関をそれぞれ計算しを比較している 14

15.

評価指標の説明 • utility(忠実度) この評価指標は、合成データがどれだけ元データに忠実かを評価する指標である。 具体的には合成データをモデルに学習し、実データでの精度を評価することで定量的に評価 本論文では実データの学習データで生成モデルを学習し、そこから出力さ れる生成データS(Train)で機械学習モデルを学習。 評価は実のテストデータセットS(Test)で行われる。 補足: 使用しているモデル • ランダムフォレスト • 決定木 • ロジスティック回帰 • Ada Boost • ナイーブ・ベイズ 評価指標は(7)の数式となり、Φはスコアを表す。 Θ𝑖 は𝑖番目の分類器(機械学習モデル)の精度を表す。 15

16.

評価指標の説明 • coverage(カバレッジ) この評価指標は合成データが、実データのカテゴリの多様性をどれくらい再現しているかどうかを評 価する。 カテゴリ列に適用する場合、𝐶𝑠𝑑 ÷ 𝐶𝑥𝑑 によって反映されたカテゴリの割合を計算する。 𝐶𝑥𝑑 : 元データのカテゴリのユニーク数 𝐶𝑠𝑑 :合成データのカテゴリのユニーク数 数値列の場合、以下の数式 ※これは実データのデータの範囲と、合成データの範囲(すなわち、その最小値と最大値)が、どれくらい密接に整合性し ているかを評価する。 16

17.

評価指標の説明 • privacy(プライバシー) この評価指標は、合成データが実データの距離とどれほど近いかを評価するものである。 評価はDCR(Distance to Closest Records)が利用される。 DCRは、合成データ点𝑠𝑛 から真正データ点𝑋までの最近接距離として決定される。 数式: 提案手法のFedTabDiffの有効性は、従来のデータ生成手法(非連帯シナリオ)と比べて評価される。 評価について: 非連帯シナリオでは、各クライアントが独立してFinDiffモデルを学習する。 これらのモデルによって生成されたデータの品質は、クライアント後に分割したデータとデータセット全体との関係性も 含めて評価される。 この分析では、多様なクライアントのデータの学習を統合した結果、データ表現を強化する上での連 携モデルの有効性が評価するのに使う。 17

18.

実験結果 • 結果は以下の表のようになった。 𝜔1 , … , 𝜔5 はモデルを共有しないで、クライアント上のデータのみで学習したもの。 FedTabDiffは提案手法。 18

19.

実験結果 • Fidelity どちらのデータセットでも提案手法の方が性能は高い。 • Utility Philadelphia: 提案手法 Diabetes: 𝜔4 上のデータセットで学習したモデル Philadelphiaで 圧倒的に性能 が良い理由は、 このデータ セットでは データのばら つきが大きい ためだと考え られる。 19

20.

実験結果 • Coverage Philadelphia: 提案手法 Diabetes: 𝜔3 上のデータセットで学習したモデル ※提案手法の方が再現度が低いように見えるが、 既存手法は、過剰に少ないカテゴリをランダムに一様分布で生成する可能性があるため(何故?)、カバレッジのスコアを大きく 見積もる可能性があり、一方で提案手法は、少ないカテゴリを生成しない可能性がある。 20

21.

実験結果 • privacy Philadelphia: 提案手法 Diabetes: 提案手法 以下の表では、小さい方が良いとされている。 しかし大きいと非現実なデータを生成するのに対して、 小さいと元データと同じデータを生成することになるので、一概どちらが良いとは言い切れない部分もある。 21

22.

実験結果 右の図は データとスコアの関係を記載したグ ラフ Client trainは学習に使用したデータ Client evalは評価に使用したデータ allは全てのデータが学習・評価に与 えられ、 𝜔1 , … , 𝜔5 は各クライアントに与えら れたデータのみを学習・評価に与え られる。 連合学習よりデータ全体を一つ のモデルに学習させるほうがス コア的には良かった。 提案手法 提案手法 22

23.

まとめ まとめ • 本手法はデータの機密性を維持しながらデータを生成する研究である。 • FedTabDiffは様々なデータセットにおいて良いパフォーマンス指標を示し、多様なシナリオにお ける有効性を証明した。 思ったこと • データの性能で言うなら、一つのデバイス・サーバーでモデルを学習させた方が性能は良い →一つのサーバにまとめてデータを生成するほうが良さそう • データの代わりにモデルのパラメータを共有することでプライバシーを守っているといえるが…? →敵対的攻撃(ホワイトボックス攻撃)のことを考慮すると、モデルの共有も危険なのでその対策も必 要 • 既存手法は、過剰に少ないカテゴリをランダムに一様分布で生成する可能性があると書かれてい たが何故か? 23

24.

Appendix: 実験の設定: モデルアーキテクチャ • • • • 層の数: 4層(各層は1024個のニューロン) ミニバッチサイズ: 512 通信ラウンド𝑅 = 1000まで学習(1000回モデルを各クライアントのローカルモデルに共有) 最適化方法: Adam optimizer 𝛽1 = 0.9, 𝛽2 = 0.999 • 拡散ステップ数𝑇 = 500 • 初期学習率𝛽𝑠𝑡𝑎𝑟𝑡 = 0.0001、最終学習率𝛽𝑒𝑛𝑑 = 0.02を • 各カテゴリカル属性は2次元の埋め込みに対応する。 分散先のローカルモデル𝜃𝑖𝜔 に対しては、 20回のモデル最適化更新を行った後、そのローカルモデル 𝜃𝑖𝜔 に対して中央のモデルの更新を行う。 →20回ローカルモデルを更新→中央サーバで集約→中央サーバからローカルモデルを更新 クライアントは5つのため𝜆 = 5とする。 24