前回の記事の続きで,論文 Li et al., Building Program Vector Representations for Deep Learning, CoRR abs/1409.3358, 2014 のパラメータ更新式の導出について書く. 多くの論文では,いちいち,この程度の式の導出過程などは書かない. 今回は,自分でプログラムを組むためにパラメータの勾配を計算したので,忘れないように導出過程をメモっておく.

確率的勾配降下法

ニューラルネットでは,多くの場合,パラメータの更新は確率的勾配降下法 (stochastic gradient descent; SGD) で行う.この方法では,パラメータ更新式をループの中で繰り返し適用することで,徐々にパラメータを最適解へと近づけていく. 最適解を解析的に求めることができないような非線形最適化問題では,よく使われる方法の1つである. 最小化すべき誤差関数を $E$,第 $\tau$ 回目のループにおけるパラメータを $\theta^{(\tau)}$ とおくと, パラメータの更新式は

である.$\eta > 0$ は学習率 (learning rate) であり,確率的勾配降下法では,ループの回数を増やす度に, 徐々に小さくする必要がある(適切に小さくしていく方法は色々ある).

ベクトルや行列での微分

今回はパラメータが,$\theta = (\mathbf{vec}(\cdot), \bm{W} _\mathrm{L}, \bm{W} _\mathrm{R}, \bm{b})$ であるから,基本的には次のように,誤差関数をベクトルや行列で偏微分するだけである.

しかし,ニューラルネットについては,途中計算にベクトルの要素単位の積とか変な式が現れるので, 無難に要素単位で偏微分した後に,ベクトル表記に直したほうが,間違いが少ない.

要素単位の RNN 計算式

というわけで,まずは前回の記事の式を要素単位の表記に書き下してみる. まず,RNN の出力 $\bm{y}$ の第 $i$ 成分は

と書ける($t = (p,c_1,c_2,\dots,c_M)$ は親を $p$,子を $c_1,c_2,\dots,c_M$ とする構文木の部分木). ただし,$a_i$ は活性 (activation),$b_i$ はバイアス $\bm{b}$ の第 $i$ 成分, $\mathrm{vec} _j(c _m)$ は特徴ベクトル $\mathbf{vec}(c_m)$ の第 $j$ 成分であり, $W _{m,i,j}$ は重み行列 $\bm{W} _m$ の第 $(i,j)$ 成分を表していて,

という式で与えられる.最小化すべき誤差関数は

であり(L2 正則化項は後で追加),$d$ は次のように与えられる.

求めるべき勾配は,以下の4種類

ただし,$n,k = 1,\dots,D$ であり,$s$ は全ての非終端記号である.

誤差の計算

あとの計算式を綺麗に書くために

を定義しておく.これを計算してみると,次のようになる.

ちなみに,論文では $f = \tanh$ である.tanh の微分は,

となる.$y_n(t) = \tanh(a_n(t))$ であるから,結局

である.この式をベクトル表記に直すと,

となる.ただし,$\otimes$ はベクトルの要素単位の積を表す.

バイアスに対する勾配

まず,関数 $\max(0,x)$ の微分について考えてみる. $x > 0$ では $\max(0,x) = x$ なので微分すると 1 になり, $x < 0$ では $\max(0,x) = 0$ なので微分すると 0 になる. したがって,導関数はステップ関数になる($x=0$ では微分できないが, 実用上は 0 か 1 など適当な値と見なしても良い).

そうすると,パラメータ $b_n$ について,

となる.ここで,合成関数の微分より,

であるため,ベクトル表記に直すと,

\begin{align} \frac{\partial E}{\partial \bm{b}} & = H(0, \Delta + d(t _+) - d(t _-)) \left( \frac{\partial d(t _+)}{\partial \bm{b}} - \frac{\partial d(t _-)}{\partial \bm{b}} \right) \
\frac{\partial d}{\partial \bm{b}} & = \bm{\omega} \end{align
}

重み行列の勾配

先ほどと同様に,合成関数の微分より,

であるから,

\begin{align} \frac{\partial E}{\partial \bm{W}_{\mathrm{L}}} & = H(0, \Delta + d(t_+) - d(t_-)) \left( \frac{\partial d(t_+)}{\partial \bm{W}_{\mathrm{L}}} - \frac{\partial d(t_-)}{\partial \bm{W}_{\mathrm{L}}} \right) \
\frac{\partial d}{\partial \bm{W}_{\mathrm{L}}} & = \bm{\omega} \left( \sum^M_{m=1} \beta_m (1 - \alpha_m) \mathbf{vec}(c_m) \right)^\top \end{align
}

となる.同じように,$\bm{W}_\mathrm{R}$ についても,以下の式が得られる.

\begin{align} \frac{\partial E}{\partial \bm{W}_{\mathrm{R}}} & = H(0, \Delta + d(t_+) - d(t_-)) \left( \frac{\partial d(t_+)}{\partial \bm{W}_{\mathrm{R}}} - \frac{\partial d(t_-)}{\partial \bm{W}_{\mathrm{R}}} \right) \
\frac{\partial d}{\partial \bm{W}_{\mathrm{R}}} & = \bm{\omega} \left( \sum^M_{m=1} \beta_m \alpha_m \mathbf{vec}(c_m) \right)^\top \end{align
}

ちなみに,L2 正則化項を付けた場合

の勾配は

で計算できる(任意の行列 $\bm{A}$ について $\partial \|\bm{A}\|^2_F / \partial \bm{A} = 2 \bm{A}$ なので). 重み行列に対する勾配以外は,L2 正則化の影響を受けないので,同じ式で計算できる.

特徴ベクトルの勾配

最後に,特徴ベクトルに対する勾配について考える. $p, c_1, \dots, c_M$ は互いに異なるシンボルとは限らないため,少し注意が必要である. まず,微分の連鎖法則より,

\begin{align} \frac{\partial d}{\partial \mathrm{vec}_n(s)} & = \sum^D_{i=1} \frac{\partial d}{\partial a_i} \frac{\partial a_i}{\partial \mathrm{vec}_n(s)} - (y_n(t) - \mathrm{vec}_n(s)) \frac{\partial \mathrm{vec}_n(p)}{\partial \mathrm{vec}_n(s)} \
& = \sum^D_{i=1} \omega_i \left( \sum^M_{m=1} \beta_m W_{m,i,n} \frac{\partial \mathrm{vec}_n(c_m)}{\partial \mathrm{vec}_n(s)} \right) - (y_n(t) - \mathrm{vec}_n(s)) \frac{\partial \mathrm{vec}_n(p)}{\partial \mathrm{vec}_n(s)} \
& = \sum^M_{m=1} \left( \beta_m \frac{\partial \mathrm{vec}_n(c_m)}{\partial \mathrm{vec}_n(s)} \sum^D_{i=1} W_{m,i,n} \omega_i \right) - (y_n(t) - \mathrm{vec}_n(s)) \frac{\partial \mathrm{vec}_n(p)}{\partial \mathrm{vec}_n(s)} \end{align
}

であるが,$\partial \mathrm{vec}_n(p)/\partial \mathrm{vec}_n(s)$ は $p=s$ に限って 1 になり,それ以外の時は 0 なので,クロネッカーのデルタ

を使って,

と書ける.これを基に,ベクトル表記に直すと,

\begin{align} \frac{\partial E}{\partial \mathbf{vec}(s)} & = H(0, \Delta + d(t_+) - d(t_-)) \left( \frac{\partial d(t_+)}{\partial \mathbf{vec}(s)} - \frac{\partial d(t_-)}{\partial \mathbf{vec}(s)} \right) \
\frac{\partial d}{\partial \mathbf{vec}(s)} & = \sum^M_{m=1} \beta_m \delta_{c_m,s} \bm{W}_m^\top \bm{\omega} - (\bm{y}(t) - \mathbf{vec}(s)) \delta_{p,s} \end{align
}

となる.クロネッカーのデルタが一見扱いにくそうに見えるけど,次のように,簡単なアルゴリズムで計算できる.

  1. 全ての非終端記号 $s$ について,$\dfrac{\partial d}{\partial \mathbf{vec}(s)} \gets \bm{0}$ として初期化,
  2. $m = 1,2,\dots,M$ について, $\dfrac{\partial d}{\partial \mathbf{vec}(c_m)} \gets \dfrac{\partial d}{\partial \mathbf{vec}(c_m)} + \beta_m \bm{W}_m^\top \bm{\omega}$,
  3. $\dfrac{\partial d}{\partial \mathbf{vec}(p)} \gets \dfrac{\partial d}{\partial \mathbf{vec}(p)} - (\bm{y}(t) - \mathbf{vec}(s))$.

実験

上の数式と前回の記事の内容を愚直にコーディングすれば,プログラムが作れる. 実際に作って,論文のデータセットの一部で学習させて求めた30次元特徴ベクトルを2次元空間にマッピングしたのが以下の図 (ごめんなさい,文字が重なって見難い)

![特徴ベクトルの分布](/img/Li14-vecs-plot.png)

論文のクラスタリング結果と概ね一致する結果になった. UnaryOp,FunCall,BinaryOp とか近くにあるし,だいだいいい感じ (EnumeratorList とか意味わかんないのもあるけど). 論文と少し違うところもあるけど,まあ,データセット全部使ったわけじゃないし,次元削減しているし, 論文の付属プログラムが論文と少し違う式を計算しているってのもあると思う. https://sites.google.com/site/learnrepresent からダウンロードしたプログラムでは, 活性化関数が tanh(x) じゃなくて tanh(2x) になっているし, 勾配計算のステップ関数が無視されていたり,バイアスと特徴ベクトルまで正則化項に含められていた. しかも,行列とかをシリアライズした1次元配列で取り回していて, 必要なときに reshape して行列(2次元配列)に変換するようにコーディングしてあったので, めっちゃ読みにくかった.

ちなみに,次元削減には主成分分析(principal component analysis)を使っている. 主成分分析についても,暇があったら記事を書きたい.

次元削減した後のデータは次の表のとおり

非終端記号 x y
root -1.0155 1.7688
FuncDef -2.5879 0.5866
Decl -1.7025 3.1864
Compound -5.4227 2.4784
FuncDecl -0.7950 1.4347
TypeDecl -2.8924 3.9249
IdentifierType -0.0477 10.4101
FuncCall -4.8774 0.6540
If -4.3813 0.1543
ID -9.9957 -3.9167
ExprList -6.4877 1.1682
Constant -11.2889 -1.9098
UnaryOp -4.8532 -0.5008
BinaryOp -6.8972 1.1328
Assignment -5.1170 0.6577
For -4.3082 0.6070
Break 0.1827 1.5458
Return -0.1384 -0.2223
ArrayDecl -1.6604 1.6452
InitList -0.5112 0.3707
ArrayRef -7.0429 0.9652
While 0.2959 0.6021
DeclList 0.1162 0.1366
ParamList 0.3104 0.8814
Cast -1.3778 -0.5955
Typename -0.0410 0.7763
EmptyStatement 1.0458 0.3817
Continue 2.1161 0.0493
PtrDecl -0.4069 0.7698
Switch -0.8380 -0.6590
Case 0.3155 0.8079
DoWhile 0.7879 -0.4900
Struct 0.0659 1.2113
StructRef -3.3681 -0.8727
Typedef 2.0958 0.2052
TernaryOp -0.0626 1.5588
Label 0.2081 1.1971
Goto -0.6034 -0.7921
Default 1.1663 0.9182
Enum 0.3183 -1.6322
EnumeratorList 0.3477 -1.5040
Enumerator 1.1773 -0.8288
Union 1.5770 0.3978
CompoundLiteral 1.2759 0.8194