>100 Views
June 23, 24
スライド概要
AI・機械学習を勉強したい学生たちが集まる、京都大学の自主ゼミサークルです。私たちのサークルに興味のある方はX(Twitter)をご覧ください!
2024年度前期輪読会#7「ゼロから作る Deep Learning」 5章 誤差逆伝播法 5.3~5.4 京都大学理学部二回生 駒井暁 0
誤差逆伝播法 目次 1. 加算ノードの逆伝播 2. 乗算ノードの逆伝播 3. 加算・乗算レイヤの実装 4. まとめ 1
1. 加算ノードの逆伝播 2
1.加算ノードの逆伝播 表式 ● 2つの入力の和を出力とする局所的な計算 (加算ノード)で逆伝播は のとき なので、下流側の微分値に1を掛けた値が上 流側に伝播していく。 3
2. 便利なテンプレ集 計算例 ● 順伝播が入力値10と5の加算の場合、逆伝播 は下流側の信号(今回は1.3)をそのまま上流 側に伝えるだけ。 ● どの加算でも逆伝播は同じ計算になる。 4
2. 乗算ノードの逆伝播 5
2. 乗算ノードの逆伝播 表式 ● 2つの入力の和を出力とする局所的な計算 (加算ノード)で逆伝播は のとき なので、下流側の微分値に順伝播時の入力値を ひっくり返した値を掛けた値が上流に伝播して いく。 6
2. 乗算ノードの逆伝播 例 ● 順伝播が入力値10と5の乗算の場合、逆伝播 は出力側の信号である1.3に、5と10をそれ ぞれ掛けた値が上流に伝わる。 ● 加算ノードとは異なり、微分値が順伝播での 入力値に依ってしまう。 7
3. 加算・乗算レイヤの実装 8
3. 加算・乗算レイヤの実装 リンゴ2個とみかん3個の買い物についての下図の状況での逆伝播 を実装してみよう。最終的な支払金額は加算と乗算の組み合わせ で成り立っている。 右向きの矢印:順伝播の信号 左向きの矢印:逆伝播の信号 9
3. 加算・乗算レイヤの実装 乗算レイヤの実装 class MulLayer: def __init__(self): self.x = None self.y = None def forward(self, x, y): self.x = x self.y = y out = x * y #後で値を代入するため初期化 #順伝播の信号を記憶 return out def backward(self, dout): dx = dout * self.y dy = dout * self.x return dx, dy 10
3. 加算・乗算レイヤの実装 加算レイヤの実装 class AddLayer: def __init__(self): pass #初期化を行う必要はない def forward(self, x, y): out = x + y return out def backward(self, dout): dx = dout * 1 dy = dout * 1 return dx, dy return dx, dy 11
3. 加算・乗算レイヤの実装 買い物における計算グラフ apple = 100 apple_num = 2 orange = 150 orange_num = 3 tax = 1.1 # layer mul_apple_layer = MulLayer() mul_orange_layer = MulLayer() add_apple_orange_layer = AddLayer() mul_tax_layer = MulLayer() # forward apple_price = mul_apple_layer.forward(apple, apple_num) # (1) orange_price = mul_orange_layer.forward(orange, orange_num) # (2) all_price = add_apple_orange_layer.forward(apple_price, orange_price) price = mul_tax_layer.forward(all_price, tax) # (4) # (3) # backward dprice = 1 dall_price, dtax = mul_tax_layer.backward(dprice) # (4) dapple_price, dorange_price = add_apple_orange_layer.backward(dall_price) dorange, dorange_num = mul_orange_layer.backward(dorange_price) # (2) dapple, dapple_num = mul_apple_layer.backward(dapple_price) # (1) # (3) print("price:", int(price)) print("dApple:", dapple) print("dApple_num:", int(dapple_num)) print("dOrange:", dorange) print("dOrange_num:", int(dorange_num)) print("dTax:", dtax) 12
誤差逆伝播法 4.まとめ まとめ1 加算ノードでの逆伝播は下流の信号はそのまま上流に送られる。 乗算ノードでの逆伝播は下流の信号に、順伝播値の入力信号をひっくり返した値 まとめ2 を掛けた値が上流に送られる。 加算ノードとは異なり、順伝播時の値を覚えておかなければならない。 まとめ3 全体で見れば複雑な計算も局所的な問題に切り替えれば、対応するレイヤを実装 することで計算を単純化できる。 13
14