866 Views
May 24, 22
スライド概要
HiPPO/S4解説 LSTMやTransformerを超える 時系列モデリングの新手法 2022/04/22 株式会社モルフォ 角田良太朗
Overview Warning 本日の内容はゴリゴリの理論です。 中身は非常に美しいのですが、なにせ1時間しかないこともあり 結構な脱落者が発生する危険が高いです。 それを踏まえ、通常の流れには逆らい、まず結果から説明します。 Copyright © 2022 Morpho, Inc. All Rights Reserved 1
Overview Long-Range Arena[1]というタスクが存在する。 • Efficient Transformer系の長距離依存性を統一的に評価するためのベン チマーク。 • トークン数1000~16000の様々なテストデータを用意。 • タスクは6種類: • LONG LISTOPS • BYTE-LEVEL TEXT CLASSIFICATION • BYTE-LEVEL DOCUMENT RETRIEVAL • IMAGE CLASSIFICATION ON SEQUENCES OF PIXELS • PATHFINDER • PATHFINDER-X(←次ページで詳述) Copyright © 2022 Morpho, Inc. All Rights Reserved 2
Overview PATHFINDER-X positive negative ([6]より引用) ([6]より引用) 128x128の画像中の2点が点線でつながっているか二値判定。 ただし、 画像はflattenして入力、 2D-Convの使用禁止。 Copyright © 2022 Morpho, Inc. All Rights Reserved 3
Overview このタスクで2021年11月に大幅なSOTA更新あり。 ([6]より引用) Copyright © 2022 Morpho, Inc. All Rights Reserved 4
Overview このタスクで2021年11月に大幅なSOTA更新あり。 ([6]より引用) S4以外のすべてのモデルは、推論に失敗していた(乱択と同程度) Copyright © 2022 Morpho, Inc. All Rights Reserved 5
Overview 本スライドの目標は、この驚異的なモデルS4とは何者なのかを解明すること。 ([6]より引用) 手法をざっと述べると。。。 Copyright © 2022 Morpho, Inc. All Rights Reserved 6
Overview 1. 時系列解析を状態空間モデルとして次のように定式化 𝑥’(𝑡) = 𝐴𝑥(𝑡) + 𝐵𝑢(𝑡) 𝑦(𝑡) = 𝐶𝑥(𝑡) + 𝐷𝑢(𝑡) (𝑢 𝑡 ∈ ℝ𝐿 :input, 𝑥 𝑡 ∈ ℝ𝑁∗𝐿 : hidden state, 𝑦 𝑡 ∈ ℝ𝐿 : output) A, B, C, D はlearned parameters AはHiPPO matrixとして初期化し、正規行列+low-rank matrixの形のみを 取るよう制限する。 Copyright © 2022 Morpho, Inc. All Rights Reserved 7
Overview 2. この式を離散化&展開することで 𝑦𝑘 = 𝐶𝐴𝑘 𝐵𝑢0 + 𝐶𝐴𝑘−1 𝐵𝑢1 + ⋯ + 𝐶𝐴𝐵𝑢𝑘−1 + 𝐶𝐵𝑢𝑘 𝑦 = 𝐾 ∗ 𝑢 (𝐾 ≔ 𝐶𝐵 + 𝐶𝐴𝐵 + ⋯ + 𝐶𝐴𝐿−1 𝐵 ∈ ℝ𝐿 ) と1D-Convの形にかける。 𝐾が高速計算できれば学習はRNNより高速。 (𝐿回のiterationをせず一発で計算できるので) Copyright © 2022 Morpho, Inc. All Rights Reserved 8
Overview 3.Kを直接計算せず、これのスペクトラム 𝐿−1 𝐹 𝐾 ≔ 𝑗=0 𝐾𝑗 𝜁𝑗 を計算してiFFTで𝐾𝑗 を一括導出したい。 𝐹(𝐾)を一般に𝐾の𝑧変換として求めることを考える。 これはAを前述の形に制限したことから、 Woodbury-Identityを用いて ෨ + 𝐿)で導出可能。 Cauchy Kernel4つの和として記述でき、特に𝑂(𝑁 (おわり) Copyright © 2022 Morpho, Inc. All Rights Reserved 9
Overview 何を言ってるのか全く分からないが、状態空間モデルの方程式を解くのに • 適切な係数空間の設計(HiPPO行列) • convolutionに書き直して高速な行列計算 をしたことで、超長距離依存性を保持できたことがポイントみたい。 実際次で見るように、これらの施策は精度にcriticalに効いている。 Copyright © 2022 Morpho, Inc. All Rights Reserved 10
Overview • 係数行列の初期化の影響(sequential CIFAR10で実験) ([6]より引用) Copyright © 2022 Morpho, Inc. All Rights Reserved 11
Overview • 高速な行列計算アルゴリズムの影響 ([6]より引用) Copyright © 2022 Morpho, Inc. All Rights Reserved 12
Overview 高度で緻密な理論設計がここまで見事に精度に反映されるような deep learning modelは見たことがない! これは解読する価値がありそうだ。。。! Copyright © 2022 Morpho, Inc. All Rights Reserved 13
Overview しかし現段階ではいろいろと謎が多すぎる。 • • • • • • 唐突に出てきた線形方程式は一体? HiPPOって何? 正規行列+low-rankでパラメータを書く動機は? Woodbury-Identityどこで使うんや? Cauchy kernelがなんで絡むん? ・・・・・・ Copyright © 2022 Morpho, Inc. All Rights Reserved 14
Overview 色々調べると以下の3本の論文が1セットになっていた。 HiPPO: Recurrent Memory with Optimal Polynomial Projections Albert Gu*, Tri Dao*, Stefano Ermon, Atri Rudra, Christopher Ré (NeurIPS 2020 Spotlight) Combining Recurrent, Convolutional, and Continuous-time Models with the Linear State Space Layer Albert Gu, Isys Johnson, Karan Goel, Khaled Saab, Tri Dao, Atri Rudra, Christopher Ré (NeurIPS 2021) Efficiently Modeling Long Sequences with Structured State Spaces Albert Gu, Karan Goel, Christopher Ré (ICLR 2022 Oral) Copyright © 2022 Morpho, Inc. All Rights Reserved 15
Overview 色々調べると以下の3本の論文が1セットになっていた。 HiPPO: Recurrent Memory with Optimal Polynomial Projections Albert Gu*, Tri Dao*, Stefano Ermon, Atri Rudra, Christopher Ré (NeurIPS 2020 Spotlight) Combining Recurrent, Convolutional, and Continuous-time Models with the Linear State Space Layer Albert Gu, Isys Johnson, Karan Goel, Khaled Saab, Tri Dao, Atri Rudra, Christopher Ré (NeurIPS 2021) Efficiently Modeling Long Sequences with Structured State Spaces First author全部同じ人や。。。 Albert Gu, Karan Goel, Christopher Ré Albert Guさん強すぎ。。。 (ICLR 2022 Oral) Copyright © 2022 Morpho, Inc. All Rights Reserved 16
Overview 色々調べると以下の3本の論文が1セットになっていた。 HiPPO: Recurrent Memory with Optimal Polynomial Projections Albert Gu*, Tri Dao*, Stefano Ermon, Atri Rudra, Christopher Ré (NeurIPS 2020 Spotlight) 今日はこれらを解説していく!! Combining Recurrent, Convolutional, and Continuous-time Models with the Linear S4を分かるには State Space Layer Albert Gu, Isys Johnson, Karan Goel, Khaled Saab, Tri Dao, Atri Rudra, Christopher Ré これ全部理解するしかない (NeurIPS 2021) Efficiently Modeling Long Sequences with Structured State Spaces First author全部同じ人や。。。 Albert Gu, Karan Goel, Christopher Ré Albert Guさん強すぎ。。。 (ICLR 2022 Oral) Copyright © 2022 Morpho, Inc. All Rights Reserved 17
Overview Warning いよいよ本番です。 頑張って付いてきてください。 分からなくなったらすぐ声を上げてください。 Copyright © 2022 Morpho, Inc. All Rights Reserved 18
HiPPO HiPPO: Recurrent Memory with Optimal Polynomial Projections (NeurIPS 2020 Spotlight) 正直この論文が一番の肝 Copyright © 2022 Morpho, Inc. All Rights Reserved 19
HiPPO 概要 長時間の時系列データを扱う際はRNNやLSTMがよく用いられるが、 • 数万ステップにもなると記憶が抜けてしまう。 • シーケンス長や時間スケールへの暗黙の依存により、テスト時に汎化しない。 • 理論的な解釈が何となくしか与えられていない。 という問題がある。 Copyright © 2022 Morpho, Inc. All Rights Reserved 20
HiPPO 概要 そこで本論文では • 記憶に関する理論的な定式化を与え、既存手法たちをその枠組みで再解釈。 • シーケンス長や時間スケールに依存しない新手法を提案。 • 新手法の優位性を、提案した枠組みを用いて厳密に証明。 している。 Copyright © 2022 Morpho, Inc. All Rights Reserved 21
HiPPO 手法 そもそも「記憶を保持する」という言葉が既に曖昧。 そこでこれを次のように言い換えるところからスタートする。 「記憶」 = 「時間依存する測度」に基づく「入力信号の多項式近似」 2 1 Copyright © 2022 Morpho, Inc. All Rights Reserved 22
HiPPO ①入力信号の多項式近似 ℝ上の入力信号𝑓: ℝ → ℝを考えよう。 目標は 𝑓 𝑥 𝑥 ≤ 𝑡 までが与えられるので、これから𝑓(𝑥)を推定することである。 𝑓(𝑥) 0 t 𝑥 Copyright © 2022 Morpho, Inc. All Rights Reserved 23
HiPPO ①入力信号の多項式近似 {𝑓(𝑥)|𝑥 ≤ 𝑡}の値をすべて使って推定するのが最善だが、我々はこれをモデルに 学習させたいので、すべての値をモデルに記憶させるのはメモリ的に厳しい。 そこで{𝑓(𝑥)|𝑥 ≤ 𝑡}を何らかの低次元表現に保存することを考える。 𝑓(𝑥) 0 𝑡 𝑥 Copyright © 2022 Morpho, Inc. All Rights Reserved 24
HiPPO ①入力信号の多項式近似 ここでは「直交多項式系」を低次元表現として採用する。 Def. 測度𝜇(𝑥)に対する直交多項式系とは多項式の集合 𝑃𝑛 𝑛=0,1,2,… であって deg 𝑃𝑛 = 𝑛, ⟨𝑃𝑛 , 𝑃𝑚 ⟩𝜇 ≔ න 𝑃𝑛 𝑥 𝑃𝑚 𝑥 𝑑𝜇 = 𝑎𝑛,𝑚 𝛿𝑛,𝑚 (∃𝑎𝑛,𝑚 ∈ ℝ) を満たすものを言う。 Rem. ℝ上の測度が与えられれば、直交多項式はスケール𝑎𝑛,𝑚 を除き一意的。 これは{1, 𝑥, 𝑥 2 , … }をグラムシュミット直交化すれば示せる。 Copyright © 2022 Morpho, Inc. All Rights Reserved 25
HiPPO ①入力信号の多項式近似 Ex. ルジャンドル多項式 は測度𝜇 𝑥 ≔ 1 −1,1 1 𝑑𝑛 2 − 1 𝑛] 𝑃𝑛 𝑥 ≔ 𝑛 [ 𝑥 2 𝑛! 𝑑𝑥 𝑛 𝑥 に対する直交多項式。 2 直交性はググれば証明出てくる。 𝑃𝑛 , 𝑃𝑚 = 𝛿 2𝑛+1 𝑛,𝑚 なおルジャンドル多項式特有の性質として次がある(後で使う) • 𝑃𝑛 1 = 1, 𝑃𝑛 −1 = −1 𝑛 • 𝑃𝑛′ = 2𝑛 − 1 𝑃𝑛−1 + 2𝑛 − 3 𝑃𝑛−2 + ⋯ • 𝑥 + 1 𝑃𝑛′ 𝑥 = 𝑛𝑃𝑛 + 2𝑛 − 1 𝑃𝑛−1 + 2𝑛 − 3 𝑃𝑛−2 + ⋯ Copyright © 2022 Morpho, Inc. All Rights Reserved 26
HiPPO ①入力信号の多項式近似 この時𝑓[𝑥≤𝑡} を直交多項式 𝑃𝑛 𝑛=0,1,…,𝑁−1 で近似することを考えよう。 直交多項式の次数は適当な𝑁未満で打ち切っていることに注意。 Fact ([2, Theorem3.10, Theorem3.5]) 1. 測度𝜇から定まる直交多項式系 𝑃𝑛 𝑛=0,1,… を固定した時、 任意の関数𝑓 ∈ 𝐿2 (ℝ; 𝜇)は以下の級数展開を持つ。 𝑓 = 𝑐𝑛 𝑃𝑛 , 𝑐𝑛 ≔ 𝑓, 𝑃𝑛 / 𝑃𝑛 , 𝑃𝑛 𝑛=0,1,… 2. 上記級数を𝑛 = 𝑁 − 1で打ち切ったものを𝑓 (𝑛) としたとき、 𝑓 (𝑛) = 𝑎𝑟𝑔𝑚𝑖𝑛𝑔∈𝑆𝑝𝑎𝑛⟨𝑃0 ,…,𝑃𝑁−1 ⟩ 𝑓 − 𝑔 𝜇 Copyright © 2022 Morpho, Inc. All Rights Reserved 27
HiPPO ①入力信号の多項式近似 気持ちとしては、「𝑓をN次未満直交多項式の張る空間に射影」している。 𝑓{𝑥≤𝑡} これにより、話を元に戻すと 𝑆𝑝𝑎𝑛⟨𝑃0 , … , 𝑃𝑁−1 ⟩ 𝑓{𝑥≤𝑡} ≈ 𝑐0 𝑃𝑜 + 𝑐1 𝑃1 + ⋯ + 𝑐𝑁−1 𝑃𝑁−1 として過去の信号履歴を(𝑐0 , 𝑐1 , … , 𝑐𝑁−1 )の𝑁変数に圧縮することができた。 さらに上式を使って未来の信号を予測することも可能! Copyright © 2022 Morpho, Inc. All Rights Reserved 28
HiPPO ②時間依存する測度 ではどんな測度(直交多項式)を選ぶのが最適だろうか? • 入力信号𝑓は時間が進むごとに履歴がどんどん積み重なるので、 測度𝜇も時間発展させた方がよいだろう。 そこで以後、時刻𝑡における測度を𝜇 𝑡 と記述する。 注意: 特に以降𝜇 𝑡 のsupportは(−∞, 𝑡]に含まれるものとする。 Copyright © 2022 Morpho, Inc. All Rights Reserved 29
HiPPO ここまでのまとめ • 「記憶」=「入力信号𝑓の過去履歴を𝑁次元直交多項式系{𝑃𝑛 }に射影」 • 直交多項式系 𝑃𝑛 は測度𝜇 𝑡 に応じて時間発展させる。 Def. 入力信号𝑓の履歴を直交多項式系に射影して、その係数を取得する操作を ℎ𝑖𝑝𝑝𝑜 𝑓{𝑥≤𝑡} ≔ 𝑐0 , 𝑐1 , … , 𝑐𝑁−1 𝑤ℎ𝑒𝑟𝑒 𝑓{𝑥≤𝑡} ≈ 𝑐0 𝑃𝑜 + 𝑐1 𝑃1 + ⋯ + 𝑐𝑁−1 𝑃𝑁−1 と書き、HiPPO operatorと呼ぶ。(HiPPO=high-order Polynomial Projection Operator) Copyright © 2022 Morpho, Inc. All Rights Reserved 30
HiPPO 「記憶」=(𝑐0 , 𝑐1 , … , 𝑐𝑁−1 )なことは分かったので、次は 「記憶のアップデート」= (𝑐0 , 𝑐1 , … , 𝑐𝑁−1 )の時間発展 がどうなっているか導出したい。 実は驚くべき結論が成り立つ。 Theorem ([3, Appendix C]) 古典的な直交関数系に対して、𝑐(𝑡)の時間発展はlinear ODEで記述できる: 𝑐′ 𝑡 = 𝐴 𝑡 𝑐 𝑡 + 𝐵 𝑡 𝑓 𝑡 , (∃𝐴 𝑡 ∈ ℝ𝑁∗𝑁 , ∃𝐵 𝑡 ∈ ℝ𝑁∗1 ) Copyright © 2022 Morpho, Inc. All Rights Reserved 31
HiPPO 冒頭で線形方程式が出てきた理由はまさにこれ。 以下これを証明し、具体例を与える。 Notations (𝑡) • 𝑓, 𝜇 (𝑡) , 𝑃𝑛 :入力信号、時刻tでの測度、付随する直交多項式 • 𝑑𝜇 (𝑡) = 𝜔 (𝑡) 𝑥 𝑑𝑥、また𝜇 (𝑡) は確率測度であると仮定(i.e. = )𝑡( 𝜇𝑑 1) • (𝑡) 𝑝𝑛 𝑥 𝑡 𝑡 𝑡 = 𝑃𝑛 (𝑥)/⟨𝑃𝑛 , 𝑃𝑛 ⟩(正規化) Copyright © 2022 Morpho, Inc. All Rights Reserved 32
HiPPO (証明) まず係数𝑐𝑛 (𝑡)の構成を思い出せば (𝑡) 𝑐𝑛 𝑡 = 𝑓≤𝑡 , 𝑃𝑛 (𝑡) (𝑡) / 𝑃𝑛 , 𝑃𝑛 𝑡 = න 𝑓 𝑥 ∗ 𝑝𝑛 𝑥 ∗ 𝜔 𝑡 (𝑥)𝑑𝑥 両辺を𝑡で微分すると(積分と微分の可換性は認める) 𝑐𝑛′ (𝑡) = න 𝑓 ∗ 𝜕 𝑡 𝑝𝑛 𝜕𝑡 𝑡 ∗ 𝜔 𝑡 𝑑𝑥 + න 𝑓 ∗ 𝑝𝑛 𝜕 𝑡 ∗ 𝜔 𝜕𝑡 𝑑𝑥 Copyright © 2022 Morpho, Inc. All Rights Reserved 33
HiPPO (証明) まず係数𝑐𝑛 (𝑡)の構成を思い出せば (𝑡) 𝑐𝑛 𝑡 = 𝑓≤𝑡 , 𝑃𝑛 (𝑡) (𝑡) / 𝑃𝑛 , 𝑃𝑛 𝑡 = න 𝑓 𝑥 ∗ 𝑝𝑛 𝑥 ∗ 𝜔 𝑡 (𝑥)𝑑𝑥 両辺を𝑡で微分すると(積分と微分の可換性は認める) 𝑐𝑛′ (𝑡) = න 𝑓 ∗ 𝜕 𝑡 𝑝𝑛 𝜕𝑡 𝑡 ∗ 𝜔 𝑡 𝑑𝑥 + න 𝑓 ∗ 𝑝𝑛 𝜕 𝑡 ∗ 𝜔 𝜕𝑡 𝑑𝑥 𝜕 𝑡 𝑝𝑛 は𝑥についての𝑛次多項式 𝜕𝑡 (𝑡) (𝑡) であるから𝑝0 , … , 𝑝𝑛 の線形和。 Copyright © 2022 Morpho, Inc. All Rights Reserved 34
HiPPO (証明) まず係数𝑐𝑛 (𝑡)の構成を思い出せば (𝑡) 𝑐𝑛 𝑡 = 𝑓≤𝑡 , 𝑃𝑛 (𝑡) (𝑡) / 𝑃𝑛 , 𝑃𝑛 𝑡 = න 𝑓 𝑥 ∗ 𝑝𝑛 𝑥 ∗ 𝜔 𝑡 (𝑥)𝑑𝑥 両辺を𝑡で微分すると(積分と微分の可換性は認める) 𝑐𝑛′ (𝑡) = න 𝑓 ∗ 𝜕 𝑡 𝑝𝑛 𝜕𝑡 𝑡 ∗ 𝜔 𝑡 𝑑𝑥 + න 𝑓 ∗ 𝑝𝑛 𝜕 𝑡 𝑝𝑛 は𝑥についての𝑛次多項式 𝜕𝑡 (𝑡) (𝑡) であるから𝑝0 , … , 𝑝𝑛 の線形和。 𝜕 𝑡 ∗ 𝜔 𝜕𝑡 𝑑𝑥 𝜕 𝜔 (𝑡) は古典的な直交関数系では 𝜕𝑡 𝜔 (𝑡) とディラック𝛿𝑡 の線形和。 Copyright © 2022 Morpho, Inc. All Rights Reserved 35
HiPPO (証明) まず係数𝑐𝑛 (𝑡)の構成を思い出せば (𝑡) 𝑐𝑛 𝑡 = 𝑓≤𝑡 , 𝑃𝑛 (𝑡) (𝑡) / 𝑃𝑛 , 𝑃𝑛 𝑡 = න 𝑓 𝑥 ∗ 𝑝𝑛 𝑥 ∗ 𝜔 𝑡 (𝑥)𝑑𝑥 両辺を𝑡で微分すると(積分と微分の可換性は認める) 𝑐𝑛′ (𝑡) = න 𝑓 ∗ 𝜕 𝑡 𝑝𝑛 𝜕𝑡 𝑡 ∗ 𝜔 𝑡 𝑑𝑥 + න 𝑓 ∗ 𝑝𝑛 𝜕 𝑡 𝑝𝑛 は𝑥についての𝑛次多項式 𝜕𝑡 (𝑡) (𝑡) であるから𝑝0 , … , 𝑝𝑛 の線形和。 𝜕 𝑡 ∗ 𝜔 𝜕𝑡 𝑑𝑥 𝜕 𝜔 (𝑡) は古典的な直交関数系では 𝜕𝑡 𝜔 (𝑡) とディラック𝛿𝑡 の線形和。 これより第一項は𝑐0 , … , 𝑐𝑛 の線形和、第二項は𝑐𝑛 と𝑓(𝑡)の線形和。(証明終) Copyright © 2022 Morpho, Inc. All Rights Reserved 36
HiPPO 実際にルジャンドル関数系を用いて実証してみる。 (図2つは[3]より引用) ただしルジャンドル関数系は[−1,1]上の関数系なので、 存在域を𝑡依存になるようスケールしてから適用する。 パターン1: [𝑡 − 𝜃, 𝑡]上に定義(𝜃 ≥ 0は何時刻前までを見るかを表すハイパラ) パターン2: [0, 𝑡]上に定義(過去の履歴をすべて見る) それぞれの場合で𝑐𝑛′ (𝑡)がどう書けるか見てみよう。 Copyright © 2022 Morpho, Inc. All Rights Reserved 37
HiPPO (図は[3]より引用) パターン1: [𝑡 − 𝜃, 𝑡]上に定義 このときの正規直交関数系は、ルジャンドル関数系 𝑃𝑛 𝑥 を用いて 1 2 𝑥−𝑡 𝑡 2 𝑝𝑛 𝑥 ≔ 2𝑛 + 1 𝑃𝑛 +1 𝜃 12 𝜕 𝑡 (𝑡) (𝑡) 𝑝𝑛 = − 2𝑛 + 1 2 2𝑛 − 1 1/2 𝑝𝑛−1 + 2𝑛 − 5 1/2 𝑝𝑛−3 + ⋯ 𝜕𝑡 𝜃 またこのとき 𝜔𝑡 1 𝑥 = 1 𝑡−𝜃,𝑡 = 𝐻𝑒𝑣𝑖𝑠𝑖𝑑𝑒 𝑥 − 𝑡 − 𝜃 𝜃 𝜕 𝑡 𝜔 = −𝛿 𝑥 − 𝑡 − 𝜃 𝜕𝑡 − 𝐻𝑒𝑣𝑖𝑠𝑖𝑑𝑒 𝑥 − 𝑡 + 𝛿 𝑥 − 𝑡 = 𝛿𝑡 − 𝛿𝑡−𝜃 Copyright © 2022 Morpho, Inc. All Rights Reserved 38
HiPPO (図は[3]より引用) パターン1: [𝑡 − 𝜃, 𝑡]上に定義 これを先ほどの式に代入すると 𝜕 𝑡 第一項 = න 𝑓 ∗ 𝑝 𝜕𝑡 𝑛 第二項 𝑡 = න 𝑓 ∗ 𝑝𝑛 12 𝑡 ∗ 𝜔 𝑑𝑥 = − 2𝑛 + 1 2 𝜕 𝑡 ∗ 𝜔 𝜕𝑡 𝜃 1 2𝑛 − 1 2 𝑐𝑛−1 + 𝑡 1 2𝑛 − 5 2 𝑐𝑛−3 + ⋯ 𝑡 𝑑𝑥 = 𝑓 𝑡 𝑝𝑛 𝑡 − 𝑓 𝑡 − 𝜃 𝑝𝑛 𝑡 − 𝜃 辺々加えてこねこねすれば、 1 1 𝑐′ 𝑡 = − 𝐴𝑐 𝑡 + 𝐵𝑓 𝑡 𝜃 𝜃 𝐴𝑛𝑘 = 1 1 2𝑛 + 1 2 2𝑘 + 1 2 ൝ 1 𝑖𝑓 𝑘 ≤ 𝑛 , 𝑛−𝑘 −1 𝑖𝑓 𝑘 ≥ 𝑛 𝐵𝑛 = 1 2𝑛 + 1 2 Copyright © 2022 Morpho, Inc. All Rights Reserved 39
HiPPO (図は[3]より引用) パターン1: [𝑡 − 𝜃, 𝑡]上に定義 1 Def. 測度 1 𝑡−𝜃,𝑡 から導出されるHiPPOの時間発展式 𝜃 1 1 𝑐′ 𝑡 = − 𝐴𝑐 𝑡 + 𝐵𝑓 𝑡 𝜃 𝜃 ただし 1 1 1 𝑖𝑓 𝑘 ≤ 𝑛 𝐴𝑛𝑘 = 2𝑛 + 1 2 2𝑘 + 1 2 ൝ , 𝑛−𝑘 −1 𝑖𝑓 𝑘 ≥ 𝑛 𝐵𝑛 = 1 2𝑛 + 1 2 をHiPPO-LegTと呼ぶ。(translated Legendre) 実はこのODEは少し式変形すると[4]の論文で提案された式と一致する。 しかし[4]の論文ではPadé approximationという別手法を用いて導出。 Copyright © 2022 Morpho, Inc. All Rights Reserved 40
HiPPO (図は[3]より引用) パターン2: [0, 𝑡]上に定義 このときの正規直交関数系は、ルジャンドル関数系 𝑃𝑛 𝑥 を用いて 𝑡 𝑝𝑛 𝑥 ≔ 2𝑛 + 1 1/2 𝑃𝑛 2𝑥 −1 𝑡 𝜕 𝑡 1 (𝑡) (𝑡) (𝑡) 𝑝𝑛 = − 2𝑛 + 1 1/2 𝑛 2𝑛 + 1 −1/2 𝑝𝑛 + 2𝑛 − 1 1/2 𝑝𝑛−1 + 2𝑛 − 3 1/2 𝑝𝑛−2 + ⋯ 𝜕𝑡 𝑡 またこのとき 1 1 𝑥 = 1 0,𝑡 = 𝐻𝑒𝑣𝑖𝑠𝑖𝑑𝑒 𝑥 − 𝐻𝑒𝑣𝑖𝑠𝑖𝑑𝑒 𝑥 − 𝑡 𝑡 𝑡 𝜕 𝑡 1 1 1 𝜔 = − 2 1 0,𝑡 + 𝛿 𝑥 − 𝑡 = (−𝜔 (𝑡) + 𝛿𝑡 ) 𝜕𝑡 𝑡 𝑡 𝑡 𝜔𝑡 Copyright © 2022 Morpho, Inc. All Rights Reserved 41
HiPPO (図は[3]より引用) パターン2: [0, 𝑡]上に定義 これを先ほどの式に代入すると 𝜕 𝑡 第一項 = න 𝑓 ∗ 𝑝 𝜕𝑡 𝑛 𝑡 第二項 = න 𝑓 ∗ 𝑝𝑛 ∗ 11 𝑡 ∗ 𝜔 𝑑𝑥 = − 2𝑛 + 1 2 𝑡 𝜕 𝑡 𝜔 𝜕𝑡 𝑑𝑥 = − 1 − 𝑛 2𝑛 + 1 2 𝑐𝑛 + 2𝑛 − 1 1/2 𝑐𝑛−1 + 2𝑛 − 3 1/2 𝑐𝑛−2 + ⋯ 1 𝑡 𝑐𝑛 𝑡 + 𝑓 𝑡 𝑝𝑛 (𝑡) 𝑡 辺々加えてこねこねすれば、 1 1 𝑐′ 𝑡 = − 𝐴𝑐 𝑡 + 𝐵𝑓 𝑡 𝑡 𝑡 𝐴𝑛𝑘 = ൞ 1 1 2𝑛 + 1 2 2𝑘 + 1 2 (𝑛 > 𝑘) , 𝑛 + 1 (𝑛 = 𝑘) 0 (𝑛 < 𝑘) 𝐵𝑛 = 1 2𝑛 + 1 2 Copyright © 2022 Morpho, Inc. All Rights Reserved 42
HiPPO (図は[3]より引用) パターン2: [0, 𝑡]上に定義 1 Def. 測度 1 0,𝑡 から導出されるHiPPOの時間発展式 𝑡 1 1 𝑐′ 𝑡 = − 𝐴𝑐 𝑡 + 𝐵𝑓 𝑡 𝑡 𝑡 ただし 1 1 2𝑛 + 1 2 2𝑘 + 1 2 (𝑛 > 𝑘) 𝐴𝑛𝑘 = ൞ , 𝑛 + 1 (𝑛 = 𝑘) 0 (𝑛 < 𝑘) をHiPPO-LegSと呼ぶ。(scaled Legendre) 𝐵𝑛 = 1 2𝑛 + 1 2 実はこれがまさに本論文で提案する新手法に他ならない! Copyright © 2022 Morpho, Inc. All Rights Reserved 43
HiPPO ラゲール、チェビシェフ、エルミート等、他の直交関数系に対しても導出が可能。 HiPPO-LegSは[0,t]のすべての時刻を見る点で直感的にHiPPO-LegTよりも優 れているが、以降でこの式が多くの嬉しい性質を満たすことを見る。 ここまでのまとめ • ℎ𝑖𝑝𝑝𝑜の出力(𝑐0 , 𝑐1 , … , 𝑐𝑁−1 )の時間変化はlinear ODEで書ける。 • ODEの係数行列は陽に書けて実際に計算可能。 • HiPPOの枠組みで既存手法を導出可能(HiPPO-LegT)。 • HiPPO-LegSという新しい時間発展式を提案。 Copyright © 2022 Morpho, Inc. All Rights Reserved 44
HiPPO 最後にHiPPO-LegSの持つ良い性質を見ていこう。 スペースの都合上、ここでは時間スケールに依存しないことだけ見る。 その他の性質は最後に結果のみを列挙する。 Copyright © 2022 Morpho, Inc. All Rights Reserved 45
HiPPO Lemma ([3, Appendix B]) HiPPO-LegSは時間スケールに依存しない。 (証明) 前述のODEを計算するにあたり、まずは離散化をしないといけない。 𝑐′ 𝑡 の両辺を積分して、 1 1 = − 𝐴𝑐 𝑡 + 𝐵𝑓 𝑡 𝑡 𝑡 𝑐 𝑡 + Δ𝑡 − 𝑐 𝑡 𝑡+Δ𝑡 1 1 =න − 𝐴𝑐 𝑡 + 𝐵𝑓 𝑡 𝑑𝑡 𝑡 𝑡 𝑡 Δ𝑡 1 1 1 1 ≈ − 𝐴𝑐 𝑡 + 𝐵𝑓 𝑡 + − 𝐴𝑐 𝑡 + Δ𝑡 + 𝐵𝑓 𝑡 + Δ𝑡 2 𝑡 𝑡 𝑡 + Δ𝑡 𝑡 + Δ𝑡 Copyright © 2022 Morpho, Inc. All Rights Reserved 46
HiPPO 辺々整理すると Δ𝑡 Δ𝑡 Δ𝑡 Δ𝑡 𝐼+ 𝐴 𝑐 𝑡 + Δ𝑡 = 𝐼 − 𝐴 𝑐 𝑡 + + 𝐵𝑓(𝑡) 2 𝑡 + Δ𝑡 2𝑡 2 𝑡 + Δ𝑡 2𝑡 なお𝑓 𝑡 + Δ𝑡 = 𝑓(𝑡)の仮定を暗黙に使った。 ここで𝑡 = 𝑘Δ𝑡, 𝑐𝑘 ≔ 𝑐 𝑘Δ𝑡 , 𝑓𝑘 ≔ 𝑓(𝑘Δ𝑡)とすれば、 1 1 1 1 𝐼+ 𝐴 𝑐𝑘+1 = 𝐼 − 𝐴 𝑐𝑘 + + 𝐵𝑓(𝑡) 2(𝑘 + 1) 2𝑘 2 𝑘+1 2𝑘 ⇒どこにもΔ𝑡が出てこない!(証明終わり) (HiPPO-LegTなど他の直交関数系だとこうはならない) Copyright © 2022 Morpho, Inc. All Rights Reserved 47
HiPPO 上の結果と合わせて、他の性質もまとめて理論説明を終わる。 ここまでのまとめ • HiPPO-LegSは時間スケールに依存しない。(ドメインシフトに強い) • HiPPOの1回のdiscretized ODE計算はO(N)。 • • 𝜕𝑐𝑙+1 𝑘 ∈ 𝑁: fixedおよび∀𝑙 > 𝑘に対して = 𝑂(1/𝑙) (勾配消失・爆発しない!) 𝜕𝑓𝑘 𝑓𝑥≤𝑡 の𝑆𝑝𝑎𝑛⟨𝑃0 , … , 𝑃𝑁−1 ⟩への射影を𝑔(𝑡) としたとき 𝑡𝐿 • 𝑓が𝐿-Lipschitzなら 𝑓𝑥≤𝑡 − 𝑔(𝑡) = 𝑂 𝑁 • 𝑓の𝑘回微分が有界なら 𝑓𝑥≤𝑡 − 𝑔(𝑡) = 𝑂 𝑡 𝑘 𝑁 −𝑘+1/2 Copyright © 2022 Morpho, Inc. All Rights Reserved 48
HiPPO 実験 HiPPOの離散漸化式をRNNに組み込んで性能評価してみる。 hidden state ℎ𝑡 の履歴を記憶させるよう下図のモデル設計を採用。 ([3]より引用) Copyright © 2022 Morpho, Inc. All Rights Reserved 49
HiPPO 実験 タスク1: Permuted MNIST ([3]より引用) Copyright © 2022 Morpho, Inc. All Rights Reserved 50
HiPPO 実験 タスク2: Character Trajectory Classification ペン先の3次元速度情報から書いている文字を当てるタスク。 ([3]より引用) サンプリングレートを変えてドメインシフトを再現しているが、 HiPPO-LegSは影響を受けていない。 Copyright © 2022 Morpho, Inc. All Rights Reserved 51
HiPPO 実験 タスク3: Copying ([3]より引用) Copyright © 2022 Morpho, Inc. All Rights Reserved 52
LSSL Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers (NeurIPS 2021) Copyright © 2022 Morpho, Inc. All Rights Reserved 53
LSSL 概要 HiPPOを以下のように改良する。 𝑥 ′ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢(𝑡) • ቊ の形に増強。 𝑦 𝑡 = 𝐶𝑥 𝑡 + 𝐷𝑢(𝑡) • 𝐴, 𝐵, 𝐶, 𝐷を学習パラメータに変更。 • 上記連立方程式がCNN/RNNの要素を含むことを証明。 Copyright © 2022 Morpho, Inc. All Rights Reserved 54
LSSL 手法 動機は論文に書いてないが、HiPPOの式を 𝑥 ′ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢(𝑡) ቊ 𝑦 𝑡 = 𝐶𝑥 𝑡 + 𝐷𝑢(𝑡) の形に増強する。 (状態空間モデルの方程式を意識してると思われる) しれっとA,B,C,Dは時間依存しないことになってる? HiPPO-LegSは係数行列は時間依存してたが。。。。。。 t → ∞ でAはほぼ変化しないので定数とみなしてるのかも。 Copyright © 2022 Morpho, Inc. All Rights Reserved 55
LSSL 手法 𝑦 𝑡 が𝑥 𝑡 と𝑢(𝑡)の線形和であることに注目する。 𝑥 ′ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢(𝑡) ቊ 𝑦 𝑡 = 𝐶𝑥 𝑡 + 𝐷𝑢(𝑡) これをLSSL(Linear State Space Layer)と呼ぶ。 Copyright © 2022 Morpho, Inc. All Rights Reserved 56
LSSL 手法 𝑦 𝑡 が𝑥 𝑡 と𝑢(𝑡)の線形和であることに注目する。 𝑥 ′ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢(𝑡) ቊ 𝑦 𝑡 = 𝐶𝑥 𝑡 + 𝐷𝑢(𝑡) これをLSSL(Linear State Space Layer)と呼ぶ。 以下この方程式が次の性質を持つことを順にみていこう。 1.線形であるために、RNNより高速に計算可能。 2.線形だと貧弱な気がするが、実は十分な表現力を持つ。 Copyright © 2022 Morpho, Inc. All Rights Reserved 57
LSSL 1.高速に計算可能 まずLSSLをbilinear離散化すると、特に第1式について積分して 𝑥’ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢 𝑡 Δ𝑡 𝑥 𝑡 + Δ𝑡 − 𝑥 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢 𝑡 + 𝐴𝑥 𝑡 + Δ𝑡 + 𝐵𝑢 𝑡 + Δ𝑡 2 ҧ 𝑡 + 𝐵𝑢 ത 𝑡 𝑥 𝑡 + Δ𝑡 = 𝐴𝑥 ただし Δ ҧ 𝐴= 𝐼− 𝐴 2 −1 Δ 𝐼+ 𝐴 , 2 Δ ത 𝐵= 𝐼− 𝐴 2 −1 Δ𝐵 Copyright © 2022 Morpho, Inc. All Rights Reserved 58
LSSL 1.高速に計算可能 この離散化式 ҧ 𝑘−1 + 𝐵𝑢 ത 𝑘 𝑥𝑘 = 𝐴𝑥 ൝ ҧ 𝑘 + 𝐷𝑢 ഥ 𝑘 𝑦𝑘 = 𝐶𝑥 から𝑥を削除すると、𝑥−1 = 0として ത 0 + 𝐷𝑢 ഥ 0 𝑦0 = 𝐶ҧ 𝐵𝑢 ത 0 + 𝐵𝑢 ത 1 +𝐷 ഥ 𝑢1 𝑦1 = 𝐶ҧ 𝐴ҧ𝐵𝑢 ത 0 + 𝐵𝑢 ത 1 + 𝐵𝑢 ത 2 +𝐷 ഥ 𝑢2 𝑦2 = 𝐶ҧ 𝐴ҧ 𝐴ҧ𝐵𝑢 ……… ത 0 + 𝐶ҧ 𝐴ҧ 𝑘−1 𝐵𝑢 ത 1 + ⋯ + 𝐶ҧ 𝐵𝑢 ത 𝑘 + 𝐷𝑢 ഥ 𝑘 𝑦𝑘 = 𝐶ҧ 𝐴ҧ 𝑘 𝐵𝑢 Copyright © 2022 Morpho, Inc. All Rights Reserved 59
LSSL 1.高速に計算可能 ഥ はお尻にしか付かないので𝐷 ഥ = 0として無視しよう。すると 𝐷 ത 0 + 𝐶ҧ 𝐴ҧ 𝑘−1 𝐵𝑢 ത 1 + ⋯ + 𝐶ҧ 𝐵𝑢 ത 𝑘 𝑦𝑘 = 𝐶ҧ 𝐴ҧ 𝑘 𝐵𝑢 ത 𝐶)ҧ ∗ 𝑢のconvolutionに他ならない。 となり、この式はまさに𝑦 = 𝐾𝐿 (𝐴,ҧ 𝐵, ത 𝐶ҧ ≔ (𝐶ҧ 𝐵, ത 𝐶ҧ 𝐴ҧ𝐵, ത … , 𝐶ҧ 𝐴ҧ𝐿−1 𝐵) ത 𝐾𝐿 𝐴,ҧ 𝐵, ここで𝐿はシーケンス長を表す。 これよりrecurrenceが不要になり、計算は高速。 Copyright © 2022 Morpho, Inc. All Rights Reserved 60
LSSL 2.十分な表現力を持つ Lemma ([5, Lemma 3.1]) LSSLはbackward-Eulerで離散化した場合、 RNNのgating mechanismを包含する。 (証明) LSSLの第一式 𝑥 ′ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢(𝑡) をbackward-Eulerで離散化すると 𝑡+Δ𝑡 𝑥 𝑡 + Δ𝑡 − 𝑥 𝑡 = න 𝐴𝑥 𝑡 + 𝐵𝑢 𝑡 𝑑𝑡 ≈ Δ𝑡 𝐴𝑥 𝑡 + Δ𝑡 + 𝐵𝑢(𝑡 + Δ𝑡) 𝑡 Copyright © 2022 Morpho, Inc. All Rights Reserved 61
LSSL 2.十分な表現力を持つ Lemma ([5, Lemma 3.1]) LSSLはbackward-Eulerで離散化した場合、 RNNのgating mechanismを包含する。 𝑥𝑘 ≔ 𝑥 𝑡 , 𝑥𝑘+1 ≔ 𝑥 𝑡 + Δ𝑡 , 𝑢𝑘+1 ≔ 𝑢(𝑡 + Δ𝑡)とし、さらにΔ𝑡 = 𝑒 𝑧 とおけば、 𝑥𝑘+1 − 𝑥𝑘 ≈ 𝑒 𝑧 𝐴𝑥𝑘+1 + 𝐵𝑢𝑘+1 𝐴𝑒 𝑧 𝐵𝑒 𝑧 𝑥𝑘+1 ≈ 1 − 𝑥𝑘 + 𝑢𝑘 𝑧 𝑧 1+𝑒 1+𝑒 ここで𝐴 = 𝐵 = 1とすれば、 𝑥𝑘+1 ≈ 1 − 𝜎 𝑧 𝑥𝑘 + 𝜎 𝑧 𝑢𝑘 となり、 これはgating mechanismに他ならない。(証明終わり) Copyright © 2022 Morpho, Inc. All Rights Reserved 62
LSSL 2.十分な表現力を持つ Lemma ([5, Lemma 3.2]) 𝑓(𝑡, 𝑥)がxについて局所Lipstizsである非線形関数としたとき、 無限にLSSLをstackしたモデルは𝑥’ 𝑡 = −𝑥 𝑡 + 𝑓(𝑡, 𝑥(𝑡))を解ける。 (証明概略) LSSLの線形部分をstackすると、それが実質ピカールの逐次近似 法を回していることになっている。 非線形部分𝑓はLSSL間にpointwise non-linearityな層を挟むことで再現す る。(証明終わり) ※この命題は本筋には使わない。詳細は各自論文参照。 Copyright © 2022 Morpho, Inc. All Rights Reserved 63
LSSL ここまでのまとめ • HiPPOにさらに線形方程式を追加したLSSLを提案。 • LSSLはconvolutionとして解釈可能なため高速。 • LSSLはRNNを含み、non-linear ODEを解くだけの能力を持つ。 Copyright © 2022 Morpho, Inc. All Rights Reserved 64
LSSL LSSLがHiPPOより真に優位であることは分かった。 次にこれを実際にどう学習に組み込むかを見ていく。 特に • Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい。 ത 𝐶ҧ ≔ (𝐶ҧ 𝐵, ത 𝐶ҧ 𝐴ҧ𝐵, ത … , 𝐶ҧ 𝐴ҧ𝐿−1 𝐵)を如何に高速計算するか。 ത • Convolution:𝐾𝐿 𝐴,ҧ 𝐵, を調べたい。 Copyright © 2022 Morpho, Inc. All Rights Reserved 65
LSSL Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい AはHiPPOで導出されるような行列のクラスに限定したい。 一体それはどんな形で書けるのだろうか? ഥ ≔ (𝑪 ഥ𝑩 ഥ𝑨 ഥ𝑨 ഥ, 𝑩 ഥ, 𝑪 ഥ, 𝑪 ഥ𝑩 ഥ, … , 𝑪 ഥ 𝑳−𝟏 𝑩 ഥ )を如何に高速計算するか Convolution:𝑲𝑳 𝑨 この式の中にはAのべき乗が大量に入っているので、愚直計算で𝑂(𝑁 3 𝐿)かかる。 もっと速く計算できないだろうか? Copyright © 2022 Morpho, Inc. All Rights Reserved 66
LSSL Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい ここで残念なお知らせ AはHiPPOで導出されるような行列のクラスに限定したい。 一体それはどんな形で書けるのだろうか? LSSLの論文でこの考察をしているが、 𝑳−𝟏 𝑩)を如何に高速計算するか Convolution:𝑲その結果はお世辞にもきれいとは言えない。 𝑳 𝑨, 𝑩, 𝑪 ≔ (𝑪𝑩, 𝑪𝑨𝑩, … , 𝑪𝑨 しかも計算は非常に不安定。 3 この式の中にはAのべき乗が大量に入っているので、愚直計算で𝑂(𝑁 𝐿)かかる。 もっと速く計算できないだろうか? Copyright © 2022 Morpho, Inc. All Rights Reserved 67
LSSL Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい これらの問題点は AはHiPPOで導出されるような行列のクラスに限定したい。 一体それはどんな形で書けるのだろうか? S4の論文にて Convolution:𝑲𝑳 𝑨, 𝑩, 𝑪 ≔ (𝑪𝑩, 𝑪𝑨𝑩, … , 𝑪𝑨𝑳−𝟏 𝑩)を如何に高速計算するか 1年越しに解決! この式の中にはAのべき乗が大量に入っているので、愚直計算で𝑂(𝑁 3 𝐿)かかる。 もっと速く計算できないだろうか? Copyright © 2022 Morpho, Inc. All Rights Reserved 68
S4 Efficiently Modeling Long Sequences with Structured State Spaces (ICLR 2022 Oral) Copyright © 2022 Morpho, Inc. All Rights Reserved 69
S4 概要 LSSLで消化不良だった • Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい。 ത 𝐶ҧ ≔ (𝐶ҧ 𝐵, ത 𝐶ҧ 𝐴ҧ𝐵, ത … , 𝐶ҧ 𝐴ҧ𝐿−1 𝐵)を如何に高速計算するか。 ത • Convolution:𝐾𝐿 𝐴,ҧ 𝐵, を解決する。 Copyright © 2022 Morpho, Inc. All Rights Reserved 70
S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい 一般的なHiPPO行列の形を導出するのは難しい。(LSSLの論文ではそれをやって大変汚いことに) そこで 「計算しやすさ」 と 「HiPPO-LegT/LegSを含む」 ことを条件に、学習する行列Aのクラスを決める。 Copyright © 2022 Morpho, Inc. All Rights Reserved 71
S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい Def. 行列𝐴 ∈ 𝑅𝑛∗𝑛 が 𝐴 = 𝐹 − 𝑝𝑞 𝑇 (𝐹: 𝑛𝑜𝑟𝑚𝑎𝑙, 𝑝, 𝑞 ∈ ℝ𝑛∗𝑘 𝑘 ≪ 𝑛 ) と書けるとき、𝐴はNPLR(Normal Plus Low-Rank)表現を持つという。 (Plusと言いつつマイナスにしているのは、本スライドでの説明の都合による) Fact 以下は同値: 1. 𝐹はnormal (i.e. 𝐹𝐹 ∗ = 𝐹 ∗ 𝐹) 2. 𝐹はユニタリ行列で対角化可能 Copyright © 2022 Morpho, Inc. All Rights Reserved 72
S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい Lemma ([6, Theorem 1]) HiPPO LegT/LegSはNPLR表現を持つ。 (証明) 以下ではHiPPO LegSのみ見ていく。このとき行列𝐴は 𝐴𝑛𝑘 = − ൞ 1 1 2𝑛 + 1 2 2𝑘 + 1 2 (𝑛 > 𝑘) 𝑛 + 1 (𝑛 = 𝑘) 0 (𝑛 < 𝑘) と書けた。 Copyright © 2022 Morpho, Inc. All Rights Reserved 73
S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい Lemma ([6, Theorem 1]) HiPPO LegT/LegSはNPLR表現を持つ。 ここで𝑝 = 2𝑛+1 2 1 2 とすると、 𝑛 𝐴 + 𝑝𝑝𝑇 𝑛𝑘 = − 1 1 1 𝑛𝑘 = 2 2𝑛 + 1 2 2𝑘 + 1 2 であり、 1 1 2𝑛 + 1 2 2𝑘 + 1 2 (𝑛 > 𝑘) 𝑝𝑝𝑇 1 2 ∗∗∗ 略 ∗∗∗ (𝑛 = 𝑘) 1 1 1 − 2𝑛 + 1 2 2𝑘 + 1 2 (𝑛 < 𝑘) 2 Copyright © 2022 Morpho, Inc. All Rights Reserved 74
S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい Lemma ([6, Theorem 1]) HiPPO LegT/LegSはNPLR表現を持つ。 すなわち 𝐴 + 𝑝𝑝𝑇 = 𝑠𝑘𝑒𝑤_𝑠𝑦𝑚𝑚𝑒𝑡𝑟𝑖𝑐 + 𝑘𝐼, ∃𝑘 ∈ ℝ の形になっており、特に右辺は正規行列。(証明終わり) Copyright © 2022 Morpho, Inc. All Rights Reserved 75
S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい これを踏まえて行列AはNPLRの中で学習させることを考える。 が、実はさらにクラスを制限しても問題ないことを次に示す。 Def. 行列𝐴 ∈ 𝑅𝑛∗𝑛 が 𝐴 = Λ − 𝑝𝑞 𝑇 (Λ: 𝑑𝑖𝑎𝑔𝑜𝑛𝑎𝑙, 𝑝, 𝑞 ∈ ℝ𝑛∗𝑘 𝑘 ≪ 𝑛 ) と書けるとき、𝐴はDPLR(Diagonal Plus Low-Rank)表現を持つという。 Copyright © 2022 Morpho, Inc. All Rights Reserved 76
S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい Lemma ([6, Lemma 3.1]) HiPPO行列に共役な作用を施しても出力は不変。 (証明) 主張がやや不明瞭だが、証明を見れば意味が分かる。 𝑥 ′ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢(𝑡) ቊ 𝑦 𝑡 = 𝐶𝑥 𝑡 + 𝐷𝑢(𝑡) に対して 𝐴, 𝐵, 𝐶, 𝐷 → (𝑉 −1 𝐴𝑉, 𝑉 −1 𝐵, 𝐶𝑉, 𝐷)の変換を施すと、 𝑉𝑥 ′ 𝑡 = 𝐴𝑉𝑥 𝑡 + 𝐵𝑢(𝑡) 𝑥 ′ 𝑡 = 𝑉 −1 𝐴𝑉𝑥 𝑡 + 𝑉 −1 𝐵𝑢(𝑡) ൝ ↔ቊ 𝑦 𝑡 = 𝐶𝑉𝑥 𝑡 + 𝐷𝑢(𝑡) 𝑦 𝑡 = 𝐶𝑉𝑥 𝑡 + 𝐷𝑢(𝑡) Copyright © 2022 Morpho, Inc. All Rights Reserved 77
S4 Aを学習パラメータにするが、HiPPOの枠組みは外れないようにしたい Lemma ([6, Lemma 3.1]) HiPPO行列に共役な作用を施しても出力は不変。 すなわち𝑉の共役な作用が𝐴にかかっても、𝐵, 𝐶を適切に変換すれば、 作用の影響は潜在変数の変数変換にとどまる。(証明終わり) これより行列𝐴をDPLRの中で学習させるとしても問題ない。 Copyright © 2022 Morpho, Inc. All Rights Reserved 78
S4 ここまでのまとめ 𝐴 = Λ − 𝑝𝑞 𝑇 として、Λ, 𝑝, 𝑞を学習させることにする。 これにより求まる𝐴の属する空間は、 古典的な直交関数形に対するHiPPO行列たちを含む。 Copyright © 2022 Morpho, Inc. All Rights Reserved 79
S4 ത 𝐶ҧ ≔ (𝐶ҧ 𝐵, ത 𝐶ҧ 𝐴ҧ𝐵, ത … , 𝐶ҧ 𝐴ҧ𝐿−1 𝐵)の高速計算 ത 𝐾𝐿 𝐴,ҧ 𝐵, ここが本論文の山場。 なんと上記のconvolutionカーネル計算を、 ෨ + 𝐿)にまで落としてしまう。 愚直計算の𝑂(𝑁 3 𝐿)からなんと𝑂(𝑁 超絶技巧が盛りだくさんなので、step-by-stepに追っていこう。 Copyright © 2022 Morpho, Inc. All Rights Reserved 80
S4 STEP0. 先ほど述べたように、 𝐴 = Λ − 𝑝𝑞 𝑇 (Λ: 𝑑𝑖𝑎𝑔𝑜𝑛𝑎𝑙, 𝑝, 𝑞 ∈ ℝ𝑛∗1 ) と置く。説明簡単化のため、𝑝, 𝑞は𝑛 ∗ 1行列とする。(HiPPO-LegSはそう) また統一性のため 𝑥 ′ 𝑡 = 𝐴𝑥 𝑡 + 𝐵𝑢(𝑡) ቊ 𝑦 𝑡 = 𝐶 ∗ 𝑥 𝑡 + 𝐷𝑢(𝑡) のように𝐶を転置して、𝐶が𝐵, 𝑝, 𝑞と同じℝ𝑛∗1 の元であるようにする。 Copyright © 2022 Morpho, Inc. All Rights Reserved 81
S4 STEP1. ത 𝐶ҧ ∗ ≔ (𝐶ҧ ∗ 𝐵, ത 𝐶ҧ ∗ 𝐴ҧ𝐵, ത … , 𝐶ҧ ∗ 𝐴ҧ𝐿−1 𝐵) ത 𝐾𝐿 𝐴,ҧ 𝐵, を直接求めるのではなく、それのz変換もどき 𝐿−1 𝐿 𝑧; 𝐴,ҧ 𝐵, ത 𝐶ҧ ∗ ≔ 𝐶ҧ ∗ 𝐴ҧ𝑖 𝐵𝑧 ത 𝑖 ∈ ℂ[𝑧] 𝐾 𝑖=0 を求めることを考える。 𝐿 から𝐾𝐿 を導出するのは、zに1のべき根を突っ込んでiFFTにより𝑂(𝐿 log 𝐿) 𝐾 Copyright © 2022 Morpho, Inc. All Rights Reserved 82
S4 STEP2. Lemma ([6, Lemma C.3]) 2 ∗ ҧ ҧ 𝐿 𝑧; 𝐴, 𝐵, ത 𝐶 = 𝐾 𝐶ሚ ∗ 𝑅 𝑧 𝐵 − 𝐶ሚ ∗ 𝑅 𝑧 𝑝 1 + 𝑞 ∗ 𝑅 𝑧 𝑝 −1 𝑞 ∗ 𝑅 𝑧 𝐵 1+𝑧 ただし 𝐶ሚ = 𝐶 𝐼 − 𝐴ҧ𝐿 , 21−𝑧 𝑅 𝑧; Λ = −Λ Δ1 + 𝑧 −1 (証明) 形式べき級数を用いて、mod 𝑧 𝐿 の下で 𝐿−1 ҧ −1 𝐵ത = 𝐶ሚ ∗ 𝐼 − 𝐴𝑧 ҧ −1 𝐵ത 𝐿 𝑧; 𝐴,ҧ 𝐵, ത 𝐶ҧ ∗ ≔ 𝐶ҧ ∗ 𝐴ҧ𝑖 𝐵𝑧 ത 𝑖 = 𝐶ҧ ∗ 𝐼 − 𝐴ҧ𝐿 𝐼 − 𝐴𝑧 𝐾 𝑖=0 Copyright © 2022 Morpho, Inc. All Rights Reserved 83
S4 また、LSSLの離散化手続きを思い出すと Δ ҧ 𝐴= 𝐼− 𝐴 2 −1 Δ 𝐼+ 𝐴 , 2 Δ ത 𝐵= 𝐼− 𝐴 2 −1 Δ𝐵 であるが、これを前述の式に代入すると以下を得る(詳細は[6, Lemma C.4]) ҧ −1 𝐵ത = 𝐶ሚ ∗ 𝐼 − 𝐴𝑧 2Δ ∗ 1 − 𝑧 𝐶ሚ 2 𝐼 − Δ𝐴 1+𝑧 1+𝑧 −1 𝐵 Copyright © 2022 Morpho, Inc. All Rights Reserved 84
S4 ここでさらに𝐴 = Λ − 𝑝𝑞 𝑇 なことを思い出すと −1 2Δ ∗ 1 − 𝑧 𝐶ሚ 2 𝐼 − Δ Λ − 𝑝𝑞 ∗ 𝐵 1+𝑧 1+𝑧 diagonal −1 2 2 1 − 𝑧 = 𝐶ሚ ∗ 𝐼 − Λ + 𝑝𝑞 ∗ 𝐵 1+𝑧 Δ1 + 𝑧 ҧ −1 𝐵ത = 𝐶ሚ ∗ 𝐼 − 𝐴𝑧 2 −1 ∗ ∗ ∗ ∗𝑅 𝑧 𝑝 ሚ ሚ = 𝐶 𝑅 𝑧 𝐵−𝐶 𝑅 𝑧 𝑝 1+𝑞 𝑞 𝑅 𝑧 𝐵 1+𝑧 なお最後の等号はWoodbury Identityから従う。(証明終わり) Fact (Woodbury Identity) 任意の行列𝐴, 𝑃, 𝑄に対して以下が成り立つ 𝐴 + 𝑈𝑉 ∗ −1 = 𝐴−1 − 𝐴−1 𝑈 𝐼 + 𝑉 ∗ 𝐴−1 𝑈 −1 𝑉 ∗ 𝐴−1 Copyright © 2022 Morpho, Inc. All Rights Reserved 85
S4 STEP2. 求めた式は一見煩雑になっただけに見えるが、よく見ると赤線部分 𝐿 𝑧; 𝐴,ҧ 𝐵, ത 𝐶ҧ ∗ = 𝐾 2 𝐶ሚ ∗ 𝑅 𝑧 𝐵 − 𝐶ሚ ∗ 𝑅 𝑧 𝑝 1 + 𝑞 ∗ 𝑅 𝑧 𝑝 −1 𝑞 ∗ 𝑅 𝑧 𝐵 1+𝑧 はすべてスカラーであり、かつ𝑅 𝑧; Λ = −1 2 1−𝑧 − Λ は対角行列。 Δ 1+𝑧 すなわち上の計算は登場する行列たちが既知なら𝑂(𝑁)で求まる。 Copyright © 2022 Morpho, Inc. All Rights Reserved 86
S4 STEP3. よってあとは新規の登場人物たち、とくに 𝐶ሚ = 𝐶 𝐼 − 𝐴ҧ𝐿 , 21−𝑧 𝑅 𝑧; Λ = −Λ Δ1 + 𝑧 −1 の2つが高速に求められれば良い。 ሚ 前者は発想の転換で、𝐶ではなく𝐶を最初から学習させることにすれば解決。 Copyright © 2022 Morpho, Inc. All Rights Reserved 87
S4 STEP3. 21−𝑧 𝑅 𝑧; Λ = −Λ Δ1 + 𝑧 −1 だが、一見対角行列なので𝑂(𝑁)で計算可能で、何も問題ないように見える。 しかしSTEP1を見直すと、我々は 𝐿 𝑧; 𝐴,ҧ 𝐵, ത 𝐶ҧ ∗ 𝐾 をすべての1の𝐿乗根に対して求める必要がある。 よってこのままだと𝑂(𝑁𝐿)かかってしまいまずい。 Copyright © 2022 Morpho, Inc. All Rights Reserved 88
S4 STEP3. どのみち𝑅(𝑧)を𝑂(𝑁)より早く求めたとしても、STEP2の計算をすべての1の𝐿 乗根に対して行うと𝑂(𝑁𝐿)かかってしまう。 そこで少し視点を変えて、一般に赤線部分 𝑉 ∗ 𝑅 𝑧 𝑈, (∀𝑈, 𝑉 ∈ ℝ𝑛∗1 ) ෨ + 𝐿)で求めることを考える。 をすべての𝑧 ∈ {1の𝐿乗根}に対して一括で𝑂(𝑁 実は𝑅(𝑧)の特殊構造により、これが可能である。 Copyright © 2022 Morpho, Inc. All Rights Reserved 89
S4 STEP3. Def. K ∈ ℝ𝑀∗𝑁 であって、 1 𝐾𝑖𝑗 = , (𝜔𝑖 , 𝜆𝑗 ∈ ℂ) 𝜔𝑖 − 𝜆𝑗 と書けるものをCauchy Kernelと呼ぶ。 Fact [7] Cauchy Kernelの行列ベクトル積にかかる計算量は ൞ 𝑂 𝑂 𝑀 + 𝑁 log 2 𝑀 + 𝑁 , 𝑒𝑥𝑎𝑐𝑡 𝑎𝑟𝑖𝑡ℎ𝑚𝑒𝑡𝑖𝑐 1 𝑀 + 𝑁 log 𝑀 + 𝑁 log , 𝑛𝑢𝑚𝑒𝑟𝑖𝑐𝑎𝑙𝑙𝑦 𝑡𝑜 𝑝𝑟𝑒𝑐𝑖𝑠𝑖𝑜𝑛 𝜖 𝜖 Copyright © 2022 Morpho, Inc. All Rights Reserved 90
S4 STEP3. これを踏まえると 21−𝑧 𝑅 𝑧; Λ = −Λ Δ1 + 𝑧 −1 はまさにCauchy Kernelに他ならない。 ෨ + 𝐿 𝑧; 𝐴,ҧ 𝐵, ത 𝐶ҧ ∗ の計算は一括で𝑂(𝑁 ゆえにすべての𝑧 ∈ {1の𝐿乗根}に対して、 𝐾 𝐿)で終わる。 ෨ + 𝐿)で計算が完了する。 STEP1のiFFTは𝑂(𝐿 log 𝐿)なので、全体としても𝑂(𝑁 (おわり!) Copyright © 2022 Morpho, Inc. All Rights Reserved 91
まとめ まとめ • 時系列モデリングの新手法HiPPOを提案。 • HiPPOを状態空間モデルの方程式に組み込み、高速なconvolution計算を実現。 • Path-Xタスクで世界初の推論成功を達成。 「所感」 • 概念基盤がかなりしっかりしていて、かつ汎用性が高い。 • 後続研究にS4をaudio generationやvideo classificationに使用した例あり。 「おまけ」 公式実装:https://github.com/HazyResearch/state-spaces 解説付きJax実装:https://srush.github.io/annotated-s4 Copyright © 2022 Morpho, Inc. All Rights Reserved. 92
まとめ 参考文献 [1] Yi Tay, Mostafa Dehghani, Samira Abnar, Yikang Shen, Dara Bahri, Philip Pham, Jinfeng Rao, Liu Yang, Sebastian Ruder, and Donald Metzler. Long range arena : A benchmark for efficient transformers. In International Conference on Learning Representations, 2021. [2] 黒田成俊. 関数解析. 共立出版. 1980. [3] Albert Gu, Tri Dao, Stefano Ermon, Atri Rudra, and Christopher R´e. Hippo: Recurrent memory with optimal polynomial projections. In Advances in Neural Information Processing Systems, pages 1474-1487, 2020. [4] Aaron Voelker, Ivana Kajić, and Chris Eliasmith. Legendre memory units: Continuoustime representation in recurrent neural networks. In Advances in Neural Information Processing Systems, pages 15544–15553, 2019. [5] Albert Gu, Isys Johnson, Karan Goel, Khaled Saab, Tri Dao, Atri Rudra, and Christopher R´e. Combining recurrent, convolutional, and continuous-time models with the structured learnable linear state space layer. In Advances in Neural Information Processing Systems, pages 572-585, 2021. Copyright © 2022 Morpho, Inc. All Rights Reserved. 93
まとめ 参考文献 [6] Albert Gu, Karan Goel, and Christopher R´e. Efficiently modeling long sequences with structured state spaces. In International Conference on Learning Representations, 2022. [7] Victor Pan. Structured matrices and polynomials: unified superfast algorithms. Springer Science & Business Media, 2001. Copyright © 2022 Morpho, Inc. All Rights Reserved. 94
Thank you