[DL輪読会]Non-Autoregressive Machine Translation with Latent Alignments

>100 Views

May 29, 20

スライド概要

2020/05/22
Deep Learning JP:
http://deeplearning.jp/seminar-2/

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

(ダウンロード不可)

関連スライド

各ページのテキスト
2.

緒⾔ 2

3.

論⽂情報 Aligned Cross Entropy for Non-Autoregressive Machine Translation Ghazvininejad et al., arXiv 2020 https://arxiv.org/abs/2004.01655 Non-Autoregressive Machine Translation with Latent Alignments Saharia et al., arXiv 2020 https://arxiv.org/abs/2004.07437 3

4.

選定理由 Alignmentと動的計画法を使うことでNon-Autoregressive MTの課題にうまく取り組めることを⽰ している 4

5.

概要 Non-Autoregressive Machine Translation 全てのトークンを⼀度で⽣成するNon-Autoregressive(NAR)モデルは⾼速に推論できる⼀ ⽅、翻訳品質は⼀般的にAutoregressive(AR)モデルに劣る Alignmentを使⽤した⼿法 予測トークンの絶対位置に強いペナルティを課すCross Entropyではなく、Alignmentを使⽤し た損失関数(Aligned Cross Entropy)を提案し性能が改善 Alignmentを直接⽣成することでNARモデル特有の困難さを回避し性能を改善 5

6.

背景 6

7.

Sequence-to-sequence model By @m3yrin. Based on dair.ai’s template ( https://github.com/dair-ai/ml-visuals ) ⼀般的なSeq2seqモデルのDecoderは逐次的(Autoregressive)にトークンを⽣成する 7

8.

Autoregressive(AR) decoding T P (Y ∣X, θ) = ∏ p (yt ∣y <t , X, θ) t=1 シンプルで⾼い性能を出す⼀⽅、前ステップの⼊⼒をもとに⽣成するので並列化できない Autoregressive decoding is the only part of sequence-to-sequence models that prevents them from massive parallelization at inference time. - Libovický and Helcl, 2018 8

9.

Non-autoregressive(NAR) decoding 翻訳品質を落とすことなく、トークン⽣成を並列化させることを⽬指す 最近様々な⼿法が提案された Iterative refinement CTC models Insertion based methods Edit-based methods Masked language models Normalizing flow models 9

10.

Non-Autoregressive Transformer (NAT) Gu et al., "Non-Autoregressive Neural Machine Translation" ICLR 2018 Salesforce Research Transformerを使⽤ ⼊⼒トークンあたりの出⼒トークンの⽣産数(Fertility)をMLPで予測 10

11.

Non-Autoregressive Transformer (Gu et al., 2018) 11

12.

Conditional Masked Language Models (CMLM) Ghazvininejad et al., "Mask-Predict: Parallel Decoding of Conditional Masked Language Models" EMNLP 2019 Facebook AI Research BERT等の事前学習のように、Maskされたトークンの予測を使って⽂章を⽣成する 最初のデコードはすべてのトークンをMaskしてdecoderへ⼊⼒、2回⽬からは予測のProbability が低いトークンを置き換えていく ターゲットの⻑さはEncoderの出⼒を使いMLPで予測 12

14.

CMLM / Results CMLMはIterationが増えるごとにARの結果に近づく ARのほうが依然として性能はいい 14

15.

NARモデルの制限 Non-Autoregressive decodingには概ね2つの難しさがあるとされる 1. 出⼒トークン間の条件付き独⽴性 2. 事前のターゲット⻑の予測 15

16.

出⼒トークン間の条件付き独⽴性 ソース列X = {x1 , ..., xn }とターゲット列Y = {y1 , ..., yT }について Autoregressive decoding T P (Y ∣X, θ) = ∏ p (yt ∣y <t , X, θ) t=1 Non-Autoregressive decoding T P (Y ∣X, θ) = ∏ p (yt ∣X, θ) t=1 16

17.

Multimodality (Gu et al., 2018) ⼀つの⽂に対する翻訳⽂は⼀つと限らない 英"Thank you." -> 独“Danke.”, “Danke schon.”, or “Vielen Dank.” トークン間の独⽴性のため、Decoderの各々のユニットは他のユニットが何を選んだのか分から ない → “Danke Dank.”, “Vielen schon.” も許可される 17

18.

Token Repetition Problem Multimodalityの影響は単語の繰り返しとしてNARモデルでは顕著に現れることから、"Token Repetition Problem"とも呼ばれる 画像 : Tu et al., 2020 (https://arxiv.org/abs/2005.00850) 18

19.

予測⻑ごとのRescoring 多くのNARのモデルでは、事前にターゲット⻑さの予測が必要 性能を出すために、複数のターゲット⻑さ候補でdecodingし、その中で最も良い品質のものを 出⼒とする 並列化できるとはいえ、計算量は無視できない 19

20.

CMLM (Ghazvininejad et al., 2019) 20

21.

Alignmentを使ったNAR decoding 21

22.

ここでのAlignmentとは ⼊⼒(ソース、モデル出⼒)列とターゲット列との間のマッピング(関数) ターゲット列に"空⽩"や繰り返し⼊れて、⼊⼒列と⻑さを揃えたもの 揃えるとは... ⼊⼒列の⻑さが10 ターゲット列がy = (A, A, B, C, D) 取りうるAlignmentの1つは、 a = (_, A, A, _, A, B, B, C, _, D) ( _ は空⽩⽂字) 22

23.

Alignmentとターゲットの関係 β(y) : ターゲットy に対して起こりうる全てのアライメントを返す関数 β −1 (a) : ターゲットを回復する関数 すべての連続した繰り返しと、空⽩トークンを除去 Input: (1, 2, 3, 4, 5, 6, 7, 8, 9, 10) ↕ 同じ⻑さ Alingment: (_, A, A, _, A, B, B, C, _, D ) ↓ β^(-1) Target: , ↑ β (A, A, B, C, D) 23

24.

#1 Aligned Cross Entropy for Non-Autoregressive Machine Translation Marjan Ghazvininejad, Vladimir Karpukhin, Luke Zettlemoyer, Omer Levy Facebook AI Research https://arxiv.org/abs/2004.01655 24

25.

Cross Entropyでの学習 NARモデルはほとんどがCross Entropyで学習される Cross Entropyは予測されたトークンの位置を厳しくみてしまい、僅かな編集距離の⽂章に対し ても⼤きなペナルティを課す NARモデルは前のトークンをみながらデコードができず、(直感的には)絶対位置に対するペナルテ ィよりも単語の⽋損へのペナルティを重視すべき 25

26.

Aligned Cross Entropy (AXE) の導⼊ モデル予測分布列P = P1 , ..., Pm とターゲット列Y = Y1 , ..., Yn のAlignmentに対して Cross Entropyを考える Alignment α をターゲット列と予測分布列をマップする関数とする α : {1, ..., n} → {1, ..., m}. また、順序の単調性を仮定する i ≤ j, if α(i) ≤ α(j). 26

27.

AXE loss 1/2 あるAlignment αにおけるConditional AXE AXE (Y1 , … , Yn , P1 , … , Pm ∣α) = n − ∑i=1 log Pα(i) (Yi ) − ∑k∈{1…m}\{α(1),…α(n)} log Pk (ε) ただし、εは空⽩トークン (第⼀項) Yi は α(i) 番⽬の予測で対数尤度を計算 (第⼆項) Alignmentが当たっていない予測は空⽩トークンの対数尤度を計算 27

29.

AXE Loss 2/2 最終的なAXE Lossは、予測確率列とターゲット列の間で最適なAlignmentで計算 AXE (Y1 , … , Yn , P1 , … , Pm ) = minα AXE (Y1 , … , Yn , P1 , … , Pm ∣α) = n minα(1)…α(n) (− ∑i=1 log Pα(i) (Yi ) − ∑k∈{1…m}\{α(1),…α(n)} log Pk (ε)) s.t. 1 ≤ α(1) ≤ α(2) ≤ α(3) … ≤ α(n) ≤ m 絶対位置ではなく、相対的な順序と語彙的な⼀致をみる NARに対してより正確な学習シグナルをもたらす 29

30.

DP Algorithm 1/3 AXE Lossを得るために、ターゲット列からモデル予測列への最も良いAlingmentを動的計画法 (Dynamic Programming, DP)で探す Y = Y1 , ..., Yn とP = P1 , ..., Pm に対して(n + 1) × (m + 1)の⾏列Aを考える Aの要素Ai,j はY1:i = Y1 … Yi とP1:j = P1 … Pj で計算されたAXE Lossに対応 30

31.

DP Algorithm 2/3 以下の3つの操作の単位(Align, Skip Prediction, Skip Target)を⾏いながら、再帰的に計算を繰 り返す Skip Targetのδ は1以上で使⽤し、Skip targetの操作を減らすためのペナルティとなる 31

32.

DP Algorithm 3/3 3つの操作の中で最⼩のConditional AXEを選びながらAを左上から埋めていく。 最後の要素An,m まで計算が終わると、 An,m が最も最適なAlignmentにおける AXE Lossの値となる 32

33.

AXE 計算コスト Aの逆対⾓線上の要素は並列化でき、計算量はO(n + m)程度 AXEの計算時間はCross Entropyを使った場合に⽐べ1.2倍程度の増加に収まるとのこと 33

34.

実験 CMLMをAXEで学習 ⼀部を除きBLEUで評価 ターゲット⻑さの候補数は5 ARモデルからのKnowledge Distillationも使⽤ ARモデルのサイズは先⾏研究に合わせる 34

35.

結果 (AXE vs Cross Entropy) AXEを使うだけで平均で5.2ほどBLEUが上がる 35

36.

結果 (vs Purely Non-Autoregressive Models) Iterativeでない(Purely Non-Autoregressive)モデルの中ではSota 36

37.

結果(蒸留なし) 37

38.

“Hedge their bets” Fig.3は、⼀つのトークンの予測確率がどの ように広がるかを⽰す。 ⻑い系列の予測では、AXEに⽐べCross Entropyにおいて予測のピークは広がる Cross Entropyで学習されたモデルは位置 がずれていても⼤きなペナルティとならない ように、周辺の位置にも確率を与える 38

39.

AXEによる Multimodalityの軽減 単語の繰り返しはCross Entropyで学習し たものに⽐べ1/12に低下 39

41.

#2 Non-Autoregressive Machine Translation with Latent Alignments Chitwan Saharia, William Chan, Saurabh Saxena, Mohammad Norouzi Google Research, Brain Team https://arxiv.org/abs/2004.07437 41

42.

Latent Alignment Model Alignmentを潜在変数(Latent Alignment)として使う log pθ (y∣x) = log ∑ pθ (a∣x) a∈β(y) このままではAlignmentの組み合わせが⼤きすぎて計算できない 42

43.

Connectionist Temporal Classification (CTC) トークン間の強い条件付き独⽴性を仮定 pθ (a∣x) = ∏ p (ai ∣x; θ) i log pθ (y∣x) = log ∑ ∏ p (ai ∣x; θ) a∈β(y) i 動的計画法を使⽤してLatent Alignmentを正確に周辺化できる Forward-Backward Algorithm (Graves et al., 2006)を使⽤ 43

44.

Imputer 固定数のステップで⽣成を⾏う反復⽣成モデル ⽣成ステップ内での条件付き独⽴、⽣成ステップ間では依存 ~で条件づけられた次のアライメントaを、以下のようにモデル化 1 step前のアライメントa ~, x) = ∏ p (ai ∣a ~, x; θ) pθ (a∣a i 44

45.

Imputer 対数尤度の下限を作ることができ、CTCと同様にして解ける 45

46.

Latent Alignment Modelの強み CTCやImputerのように潜在的にAlignmentを使う事で、NAR特有の制限を回避できる トークンの繰り返しはAlignmentとして扱われる Token Repetitionはβ −1 (a)で機械的に折りたたまれる Alignmentを出⼒するモデルであり、Alignmentはソース列と同じ⻑さであることから、列の⻑さ の必要はない Target Length Predictionは不要 46

47.

AlignmentをNMTで使うために CTC(やImputer)は⾳声認識のようなタスクで使われてきた Alignmentについては以下のような仮定をおいて使⽤される i. Alignmentとターゲットの間に単調なマッピングが存在する ii. ターゲットの⻑さは、ソースの⻑さ以下である。すなわち∣x∣ ≥ ∣y∣ 47

48.

単調性の問題 機械翻訳ではソースとターゲットとの間に単調な順序関係は存在しない Transformerはターゲットとほぼ単調な関係になるようなAlignmentを出⼒できると仮定 48

49.

⻑さの問題 ソースの⻑さ以上のターゲットは普通にある ソースを元の⻑さの s 倍にアップサンプリングして使⽤ ⼊⼒トークンの埋め込みx ∈ R∣x∣×d (dは埋め込み次元)を線形変換するだけ x′ ∈ Rs⋅∣x∣×d 49

50.

モデル構造 Encoder-Decoderの構造を持たず、シンプルな構造となる 50

51.

実験 モデルサイズはBase Transformerと同じ ARモデルからのKnowledge Distillationも使⽤ ARモデルのサイズは先⾏研究に合わせる CTC(Libovický and Helcl, EMNLP 2018)との違いは蒸留の有無だけ 51

52.

Single step decoding 52

53.

Single step decoding Single stepでの性能はSota トークンの繰り返し率も有意に低い (それはそう...) 53

54.

Imputer/Iterative decoding 54

55.

蒸留の効果 蒸留の効果はCTCで特に⼤きい 反復回数が⼤きいほど効果は限定 的 55

56.

まとめ 予測トークンの絶対位置に強いペナルティを課すCross Entropyではなく、Alignmentへの損失 関数(Aligned Cross Entropy)を提案し性能が改善 Alignmentを直接⽣成することでNARモデル特有の困難さを回避し性能を改善 56

57.

雑感 繰り返しが起こってしまう現象をAlignmentとして扱うのは賢い 同じような考え⽅は広く使えるのでは ARモデルをAXEで学習しても効果があるのでは Cross Entropyは位置ずれに厳しすぎた 翻訳機能を提供してる企業が中⼼になって研究が進んでいるようにみえる Google, Facebook AI Research, Salesforce 57

58.

参考資料 ちょっと変わったDecodingの⽅法 (*以前書いたブログ記事) https://qiita.com/m3yrin/items/7a6df0d7f5d44f0b5efc 58

59.

予備 59

60.

繰り返しの削除による効果 単純に繰り返しを削除するだけでも効果が出る? 先⾏研究にて繰り返し語の機械的な削除が試されている Lee et al., EMNLP 2018. "Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement" 単なる削除では、BLEUの1ポイントほどの改善にとどまった 60