UnityとゲームAIと将棋

Unity、Pythonを中心にゲーム開発やゲームAI開発の技術メモ等、たまに将棋も

【機械学習】RNNの内部計算メモ

本記事の概要

前回のFNNの内部計算の記事

tsubasa-alife.hatenablog.com

の続きで、本記事ではRNNの内部計算について備忘録的にメモしておきます。

RNNの計算グラフ

RNNの計算グラフ

※各パラメタ・要素の説明
 x_t : 時刻tにおける入力
 W_{ih} : 入力層と中間層の結合重み
 W_{dh} : 前時刻t-1の中間層と現時刻tの中間層の結合重み
 b_h : 中間層におけるバイアス
 h_t : 時刻tにおける中間層への入力。内部状態。
 tanh : 活性化関数
 d_t : 時刻tにおける中間層の出力。コンテキスト。
 W_{do} : 中間層と出力層の結合重み
 b_o : 出力層におけるバイアス
 y_t : 時刻tにおける出力値

RNNの内部計算

基本的な計算はFNNの時と同じですが、異なる部分としては中間層の計算を行う際に前時刻の中間層の出力を利用するという部分があります。過去の中間層の出力を現在での中間層における計算に利用することで、過去の履歴を考慮した出力を行うことができます。このような構造を持っているためRNNは時系列データなどの予測等に利用されることが多いです。

誤差関数


E=\dfrac{1}{2}\left(y_t-\hat{y}_t\right)^2

 \hat{y}_t は時刻tにおける教師データ

順方向計算


h_t=W_{ih}x_t+W_{dh}d_{t-1}+b_h \\
d_t=tanh(h_t) \\
y_t=tanh(W_{do}d_t+b_o)

誤差逆伝播計算


\dfrac{\partial E}{\partial W_{do}}=\dfrac{\partial E}{\partial y_t}\cdot\dfrac{\partial y_t}{\partial W_{do}}=(y_t-\hat{y_t})(1-y_t^2)d_t \\
\dfrac{\partial E}{\partial b_o}=\dfrac{\partial E}{\partial y_t}\cdot\dfrac{\partial y_t}{\partial b_o}=(y_t-\hat{y}_t)(1-y_t^2) \\
\dfrac{\partial E}{\partial d_t}=\dfrac{\partial E}{\partial y_t}\cdot\dfrac{\partial y_t}{\partial d_t}+\dfrac{\partial E}{\partial h_{t+1}}\cdot\dfrac{\partial h_{t+1}}{\partial d_t}=(y_t-\hat{y_t})(1-y_t^2)W_{do}+\dfrac{\partial E}{\partial h_{t+1}}\cdot W_{dh} \\
\dfrac{\partial E}{\partial h_t}=\dfrac{\partial E}{\partial d_t}\cdot\dfrac{\partial d_t}{\partial h_t}=\bigl\{(y_t-\hat{y_t})(1-y_t^2)W_{do}+\dfrac{\partial E}{\partial h_{t+1}}\cdot W_{dh}\bigr\}(1-d_t^2) \\
\dfrac{\partial E}{\partial W_{ih}}=\dfrac{\partial E}{\partial h_t}\cdot\dfrac{\partial h_t}{\partial W_{ih}}=\bigl\{(y_t-\hat{y_t})(1-y_t^2)W_{do}+\dfrac{\partial E}{\partial h_{t+1}}\cdot W_{dh}\bigr\}(1-d_t^2)x_t \\
\dfrac{\partial E}{\partial W_{dh}}=\dfrac{\partial E}{\partial h_t}\cdot\dfrac{\partial h_t}{\partial W_{dh}}=\bigl\{(y_t-\hat{y_t})(1-y_t^2)W_{do}+\dfrac{\partial E}{\partial h_{t+1}}\cdot W_{dh}\bigr\}(1-d_t^2)d_{t-1} \\
\dfrac{\partial E}{\partial b_h}=\dfrac{\partial E}{\partial h_t}\cdot\dfrac{\partial h_t}{\partial b_h}=\bigl\{(y_t-\hat{y_t})(1-y_t^2)W_{do}+\dfrac{\partial E}{\partial h_{t+1}}\cdot W_{dh}\bigr\}(1-d_t^2)

いわゆるBPTT(Backpropagation Through Time)と呼ばれる計算です。

パラメタ更新式


W_{do}\leftarrow W_{do}-\alpha\cdot\dfrac{\partial E}{\partial W_{do}} \\
b_o\leftarrow b_o-\alpha\cdot\dfrac{\partial E}{\partial b_o} \\
W_{ih}\leftarrow W_{ih}-\alpha\cdot\dfrac{\partial E}{\partial W_{ih}} \\
W_{dh}\leftarrow W_{dh}-\alpha\cdot\dfrac{\partial E}{\partial W_{dh}} \\
b_h\leftarrow b_h-\alpha\cdot\dfrac{\partial E}{\partial b_h}