4.5K Views
May 16, 24
スライド概要
DL輪読会資料
DEEP LEARNING JP [DL Papers] Sudden Drops in the Loss: Syntax Acquisition, Phase Transitions, and Simplicity Bias in MLMs Gouki Minegishi, Matsuo Lab, M2 http://deeplearning.jp/
書誌情報 選定理由 • ICLR2024のspotlight (10,8,8,5) • Posterで見たけど内容が難しそうでわからなかったから • Naomi Saphraの研究が気になってたから 2
背景 • 一般的な言語モデルのLossはスムーズに下がる一方で能力の発現は一様ではない • このような急激な能力の向上は, breakthroughs, emergence, breaks, phase transitionと呼ばれている • しかし解釈性における先行研究では学習済みモデルを分析することが多く, 学習の過程を分析しているものは少ない • この論文では特にSASと呼ばれる文法構造に関連したモデルの内部構造 についての学習ダイナミクスを分析する In-context learning and Induction head 3
Syntactic Attention Structure (SAS) • 言語は木構造 – 文の各単語の修飾-被修飾関係を考えると(親)builds→(子)nests, (親)nests→(子)ugly • 特定の単語間の依存関係に注目したattentionのheadが存在する[1,2,3] – 例えばあるheadは(親)builds→(子)nestsの関係を捉えているなど • このようなheadの形成をSyntactic Attention Structure(構文的注意構造)と呼ぶ 4
Unlabeled Attachment Score (UAS) SASを定量的に測る指標を提案(UAS) 1. ある単語𝑥𝑖 に関してattention scoreが最大の単語𝑥𝑗 2. そのような(𝑥𝑖 , 𝑥𝑗 )が特定の関係にあるか(あったら1無ければ0) 3. 全ての単語間の関係で平均 5
実験設定 • BERT – 12 layer 768 dimension – AdamW optimizer – 文の15%がマスクされているマスク予測問題を解く • データ – Book Corpus – English Wikipedia • 単語間の関係 – Spacyをラベルとする 6
SASを促進/抑制する • 通常の𝐵𝐸𝑅𝑇𝑏𝑎𝑠𝑒 の学習にsyntactic regularizerを導入 – 修飾関係のある単語間(𝑥𝑖 , 𝑥𝑗 )に貼られるattentionにペナルティを加える • 𝜆 が負のもの𝐵𝐸𝑅𝑇𝑆𝐴𝑆+ – SASの形成が促進される • 𝜆 が正のもの𝐵𝐸𝑅𝑇𝑆𝐴𝑆− – SASの形成が抑制される 7
評価 • Evaluating on BLiMP – 英語のさまざまな文法現象に関する知識を評価するベンチマーク • UAS – モデル内部で文法を捉えているかどうかのメトリクス • Fine-tuning on GLUE – 言語系タスクでよく使われるベンチマーク • 相転移の検知 – あるメトリクスを𝑓(𝑡)としたときに以下の式で相転移を検知する 8
2段階の相転移 • 𝐵𝐸𝑅𝑇𝑏𝑎𝑠𝑒 はまず文法構造を捉えている内部構造(SAS)の形成の急激な変化が始まる (△:UASの相転移点) • その後,UASがプラトーになると(モデルが内部で文法を獲得すると) 同時にモデルのパフォーマンスの急激な向上が始まる( :Accの相転移点) → モデル内部の文法構造の獲得がモデルの能力の向上に寄与している △ 内部構造(UAS)の相転移 モデルの能力(Acc)の相転移 9
単純性バイアス TwoNNという手法を用いてモデルの複雑性を評価 • 構造の発現の直前に急激にモデルが単純化している – 学習初期ではモデルは単純な特徴から学習しやすい傾向にあり,[4,5,6] これを単純性バイアスと呼ぶ • 直感的にはSAS(人間が解釈可能)の獲得には単純性が必要であるということ △ 内部構造(UAS)の相転移 モデルの能力(Acc)の相転移 10
SASを促進/抑制する • SASの形成を促進した𝐵𝐸𝐴𝑇𝑆𝐴𝑆+は,𝐵𝐸𝑅𝑇𝐵𝑎𝑠𝑒に比べてUASの向上は早く起きるが 性能は悪化してしまう • SASの形成を抑制した𝐵𝐸𝐴𝑇𝑆𝐴𝑆−は,UASは全く向上しないがある程度性能は上がり, 学習初期,Lossは𝐵𝐸𝑅𝑇𝐵𝑎𝑠𝑒, 𝐵𝐸𝐴𝑇𝑆𝐴𝑆+よりも早く下がる → SASを獲得しようとするとLossが停滞する つまりSASはニューラルネットワークの他の特性と競合している? 11
学習初期段階でのSASの抑制 • 学習の途中(3k)までSASを抑制し,その後正則化の係数を0にする → 元のモデルよりも精度が向上 • また学習の後半で抑制を解放してもUASが上がらない (SASが回復することはなかった) 12
Discussion • 学習ダイナミクスと単純性バイアス – 情報が豊富な事前分布(帰納バイアス)に基づいたモデルは大量のデータで学習した一般的な モデルに劣る [7] というBitter Lessonと整合 – 自然現象に対する人間に認識はすごく単純なため,それが訓練の初期に存在すること (e.g., SAS)はニューラルネットにとってはネガティブなシグナルになりうる – またカリキュラム学習が大規模でうまくいかない[8]理由も,単純なデータを学習初期に導入 することで初期のlossは下がるが,学習後半のパフォーマンスが損なわれるからでは? • 相転移 – 普通のvalidation lossだと相転移は見られないが,今回のような現実的な設定でも起きうる – また,実験では相転移中に正則化変更することが 一番性能を悪化させた – 相転移中のoptimizerの誤操作は収束時のパフォーマンス にも影響を与え,スムーズにlossが上がる場合は このような相転移を見逃しやすい 13
まとめ • モデルが内部でどれくらい文法を理解しているかを定量化し, 2段階(構造の発現とパフォーマンスの向上)の学習ダイナミクスを示した. • SASに関する正則化項を導入し,SASの必要性を示し, 他のネットワーク特性と競合していること示した. • 一時的にSASを抑制することで,モデルのパフォーマンスが上がることを示し た. 14
参考文献 [1] Emergent linguistic structure in artificial neural networks trained by self-supervision [2] Analyzing multi-head self-attention: Specialized heads do the heavy lifting, the rest can be pruned [3] What Does BERT Look At? An Analysis of BERT’s Attention, [4] What shapes feature representations? Exploring datasets, architectures, and training [5] The Pitfalls of Simplicity Bias in Neural Networks [6] GD on Neural Networks Learns Functions of Increasing Complexity [7] The bitter lesson. [8] 15