【DL輪読会】Efficiently Modeling Long Sequences with Structured State Spaces

5.1K Views

December 03, 21

スライド概要

2021/12/03
Deep Learning JP:
http://deeplearning.jp/seminar-2/

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

(ダウンロード不可)

関連スライド

各ページのテキスト
1.

DEEP LEARNING JP [DL Papers] “Efficiently Modeling Long Sequences with Structured State Spaces” Naoki Nonaka http://deeplearning.jp/ 2023/10/10 1

2.

書誌情報 • 会議:ICLR2022 投稿(評価: 8, 8, 8) (本スライドはArxivに投稿されている論文に基づいて作成) • 著者: 2023/10/10 2

3.

概要  長距離の依存関係を持つ系列データの問題に取り組んだ研究  SSM(状態空間モデル)x Deep Learningのアプローチを提案  ベンチマークにて既存手法を大幅に上回る性能を実現 2023/10/10 3

4.

背景 長距離の依存関係(Long-range dependencies: LRD) 依存関係 … T  実世界のデータでは,数万ステップでの推論が必要 (具体例としては,音声や言語情報など)  LRDに取り組んだ深層学習による従来の手法としては, RNN, CNNやTransformerとその改良手法が提案されてきた 2023/10/10 4

5.

背景 LRDに取り組んだ従来手法の利点と欠点 O X RNN ステップごとの 計算量/ストレージが一定 学習に時間がかかる 最適化が難しい(Vanishing gradient) CNN 並列可能で高速に学習できる 逐次学習ではないので 推論時のコストが高い/扱える長さに制限 (Transformer系の手法もCNNとほぼ同じ) 2023/10/10 5

6.

背景 理想的な時系列モデル  各時刻における状態を保持し,推論が可能(recurrence)  並列計算による学習が可能(convolutional)  任意の時間軸適応(微分方程式の性質) 状態空間モデル(State Space Model; SSM) 2023/10/10 6

7.

背景 状態空間モデル 図は[1]をもとに改変  入力,出力,状態の3つの変数からなる数学的モデル  多くの数理モデルの基礎となっているモデル 状態空間モデル x 深層学習の手法は存在しなかった※ ※ 厳密には同一著者の先行研究[1]が不完全ながら取り組んでいる 2023/10/10 7

8.

提案手法: S4 S4: Structured State Space sequence model → 状態空間モデル x 深層学習の手法 S4の導出過程 1. SSMのRecurrent表現とConvolution表現の導出 2. HiPPO行列による連続時間記憶の問題の解決 ※ 3. SSM convolutionカーネル(後述)の計算の効率化 ※ 同一著者の先行研究[1]における工夫と同じ 2023/10/10 8

9.

S4: Recurrent表現とConvolution表現の導出 S4 (SSM): 再帰的な計算と並列学習が可能 連続時間SSM 2023/10/10 離散時間SSM RNN様の再帰的な計算が可能に 畳み込み演算での表現 CNN様の並列計算が可能に 9

10.

S4: Recurrent表現とConvolution表現の導出 離散時間SSM ◼ 間隔Δで離散化 ◼ Bilinear法を使用  離散化により,離散的な入力データを扱えるようになる  RNNと同じく再帰的な処理が可能になる 2023/10/10 10

11.

S4: Recurrent表現とConvolution表現の導出 畳み込み演算での表現 展開 SSMの畳み込み演算 ഥ) を定義 SSM convolution kernel (K 2023/10/10 11

12.

提案手法: S4 S4: Structured State Space sequence model → 状態空間モデル x 深層学習の手法 S4の導出過程 1. SSMのRecurrent表現とConvolution表現の導出 2. HiPPO行列による連続時間記憶の問題の解決 ※ 3. SSM convolutionカーネル(後述)の計算の効率化 ※ 同一著者の先行研究[1]における工夫と同じ 2023/10/10 13

13.

S4: HiPPO行列による連続時間記憶 HiPPO: High-order Polynomial Projection Operators 図は[2]より  直交多項式の重み付き和によって過去の系列を表現  RNNに組み込むと記憶性能が向上する 2023/10/10 14

14.

提案手法: S4 S4: Structured State Space sequence model → 状態空間モデル x 深層学習の手法 S4の導出過程 1. SSMのRecurrent表現とConvolution表現の導出 2. HiPPO行列による連続時間記憶の問題の解決 ※ 3. SSM convolutionカーネル(後述)の計算の効率化 ※ 同一著者の先行研究[1]における工夫と同じ 2023/10/10 19

15.

S4: SSM convolution kernelの計算 SSMの学習の並列化 連続時間記憶の改善 Aの冪乗計算が必要 AはHiPPO行列である必要 HiPPO行列の冪乗計算が必要 2023/10/10 20

16.

S4: SSM convolution kernelの計算 ഥ の計算: 行列Aの冪乗計算を含むため工夫が必要 K  Aを,対角行列Λ + 低ランク行列 p, q (rank=1)  3つの計算工夫を導入 2023/10/10 21

17.

S4: SSM convolution kernelの計算 1. FFTによる冪乗計算の回避 𝑧におけるSSM母関数を定義 (詳細はAppendix C3) 数列𝑎𝑛 に対する母関数 ∞ 𝑓 𝑥 = ෍ 𝑎𝑘 𝑥 𝑘 𝑘=0 𝑧を1の冪根とすると, 1の冪根 ◼ 𝜍 = exp 2𝜋𝑖 𝑛 ◼ ある𝑛に対して 𝑧 𝑛 = 1を満たす𝑧 → 離散フーリエ変換と一致 ഥ を得る SSM母関数で冪乗計算を逆行列計算化 + 逆FFTで K 2023/10/10 22

18.

S4: SSM convolution kernelの計算 2. 対角行列 + 低ランク行列の逆行列計算 Woodbury恒等式を利用 SSM母関数における逆行列計算を効率化 3. Cauchyカーネルによる計算 Aが対角行列のときSSM母関数の計算 = Cauchyカーネルの計算 Cauchyカーネルの計算アルゴリズムを利用 2023/10/10 23

19.

S4 layer 実装上は,系列を受け取り,系列を出力する層となる __init__ 内 forward 内 … Dropout S4 LayerNorm … Input https://github.com/HazyResearch/state-spaces/blob/main/example.py 2023/10/10 24

20.

実験  計算効率  長距離の依存関係の学習  汎用系列モデルとしての性能 2023/10/10 25

21.

実験: 計算効率  LSSL(状態空間モデル系の先行研究)よりも高速・高メモリ効率  (Efficientな)Transformer系と同程度に高速・省メモリ 2023/10/10 26

22.

実験: 長距離の依存関係の学習  Long Range Arena (LRA)  (主にTransformer系の手法を念頭にした) 長距離の依存関係のモデリング性能を評価するためのデータセット []  6つのタスクで構成される  Raw speech classification  Speech Commandデータセット(35クラス,100,503件のサンプル)  話し言葉の音声データの中からキーワードを検出するタスク 2023/10/10 27

23.

実験: 長距離の依存関係の学習 (LRA: 1/4) 1. <LISTOPS> Long ListOps 複数の演算子(MAX, MEAN, MEDIAN, SUM_MOD)の階層構造で 表現された系列から出力となる数字を当てるタスク 2. <TEXT> Byte-level Text classification  IMDbレビューをもとに作成されたデータセット  byte/character-levelで分類 2023/10/10 28

24.

実験: 長距離の依存関係の学習 (LRA: 2/4) 3. <RETRIEVAL> Byte-level Document Retrieval  長い文章を短い表現に圧縮し,文章の類似度を評価するタスク  元データはIMDbのレビュー  系列長は4k(長いものはtruncate, 短いものはpadding) 4. <IMAGE> Image Classification on sequence of pixels  Sequential MNISTのCIFAR-10版  系列長3072 (= 32 x 32 x 3) のサンプルを10クラスに分類 2023/10/10 29

25.

実験: 長距離の依存関係の学習 (LRA: 3/4) 3. <PATHFINDER> PathFinder 画像中の2点が破線でつながっているか判定 入力:32 x 32の画像の系列(=784) 出力:二値(2点がつながっているか) 4. <PATH-X> PathFinder-X PathFinderタスクを128 x 128に拡大した画像で実施 2023/10/10 30

26.

実験: 長距離の依存関係の学習 (LRA: 1/4)  6つのタスク全てで既存手法を大幅に上回る  PathFinder-Xを解けた唯一のモデル 2023/10/10 31

27.

実験: 長距離の依存関係の学習  Long Range Arena (LRA)  (主にTransformer系の手法を念頭にした) 長距離の依存関係のモデリング性能を評価するためのデータセット []  6つのタスクで構成される  Raw speech classification  Speech Commandデータセット(35クラス,100,503件のサンプル)  話し言葉の音声データの中からキーワードを検出するタスク 2023/10/10 32

28.

実験: 長距離の依存関係の学習(Speech: 1/1)  MFCCによる前処理あり:先行研究と同程度の性能  Rawデータでの分類:WaveGANを上回る性能 2023/10/10 33

29.

実験: 汎用系列モデルとしての性能  大規模な生成モデルの学習  CIFAR-10における密度推定  WikiText-103における言語モデリング  自己回帰による推論  CIFAR-10およびWikiText-103での生成速度を比較 2023/10/10 34

30.

実験: 汎用系列モデルとしての性能 大規模な生成モデルの学習/自己回帰による推論  先行研究と同程度の性能を達成  自己回帰による推論の速度は60倍以上高速化 2023/10/10 35

31.

実験: 汎用系列モデルとしての性能 不規則にサンプリングされたデータの扱い  Test時のみ周波数を0.5倍にして評価(右列)  S4では,追加学習なしでも周波数の 変化に対して頑健になっている 2023/10/10 36

32.

結論・まとめ  状態空間モデルにDNNを取り込んだS4モデルを提案  LRAにて既存手法を大幅に上回る性能を実現  汎用系列モデルとしても優れた性能を示す 2023/10/10 37

33.

Reference 1. Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers 2. HiPPO: Recurrent Memory with Optimal Polynomial Projections 2023/10/10 38

34.

Appendix 2023/10/10 39