技術をかじる猫

適当に気になった技術や言語、思ったこと考えた事など。

ニューラルネットワーク各階層の勾配計算式

勾配計算式

w を重み、b をバイアス、E を誤差(損失関数出力)とするとこんな形状で定式化されてる。
この辺はいくつかの書籍見て、ようやっと飲み込めた感じ…。

数式を飲み込むのにはそれなりに時間を要したけど…。

  • 出力層
δk=Euk=Eykykukwjk=Ewjk=yjδk()bk=Ebk=δk()yj=Eyj=nr=1δrwjr()
  • 中間層
δj=Euj=yjykujwij=Ewij=yiδj()bj=Ebj=δj()yi=Eyi=mq=1δqwiq()

結局のところ、δn さえ算出できれば、残りの計算式は芋づるで計算できることになる。

求めた過程

原則的に、ニューロン(出力層含む)の作りは、入力 y と重み wjk + バイアス bk によるもので

uk=yjwjk+bkyk=f(uk)

でできてる。
出力層に至ってはここに損失関数を食わせて E (理想値との誤差)を取得する。
この時、勾配降下法を用いるなら、それら式を微分して、勾配を求めていくことになる。

ニューロンが学習するべき値は wb で、これらは行列なので偏微分でひとつづつ誤差への影響を求めていくことになる。

wjk=Ewjk=Eukukwjk...()

※連鎖律を適用。連鎖律については ここ 参照

ukwjkuk を展開すると

ukwjk=(mq=1yqwqk+bk)wjk=wjk(y1w1k+y2w2k...yjwjk...+ymwmk+bk)

偏微分なので、 wjk のかかってない項はすべて 0 なので、

ukwjk=wjk(y1w1k+y2w2k...yjwjk...+ymwmk+bk)=yj

次に Euk に焦点を当てて、こちらも連鎖律で展開して、

Euk=Eykykuk

これを δk とすると

δk=Eykykukwjk=yjδk

という形に持っていける。

bk も同様の手順で考えると

bk=Ebk=Eukukbk

uk を展開すると、偏微分したときに bk の1項しか残らず、最終的に 1 になるので、最終形態は

bk=δk

中間層出力の勾配も考えると、

yj=Eyj=nr=iEururyj

このうち、

uryj=(mq=1)yqwqr+bryj=yj(y1w1r+y2w2r+...yjwjr+...+ymwwr+br)=wjr(yjyj0)

ここで δr=Eurとすれば

yj=nr=1δrwjr

中間層の重み勾配を考えると、

wij=Ewij=Eujujwij

このうち ujeij

ujwij=(lp=1ypwpj+bj)wij=yi()

もう一つ Euj 部分は連鎖律使って

Euj=Eyjyjuj

このうち yjuj は活性関数の微分
Eyj は中間層出力の勾配(= yj )。

δj=Euj=yjyjuj

と、なるほど。中間層出力の勾配を受け取って、中間層が計算できる → 下層からの結果を受け取る → 逆伝播の原理 というわけですね

ujwij=yi

を使うと wij=yiδj

バイアス勾配 bj

bj=Ebj=Eujujbj

このうち

ujbj=(lp=1ypwpj+bj)bj=1()

偏微分なので、bj 以外の全ての項は 0 になり、残った bj微分されるため。

そのため

bj=δj

こんなノリで上の層に伝播を続けていく。

ニューラルネットに実際に適用して考える

損失関数を二乗誤差、活性関数を恒等関数とした出力層の場合

δk=Euk=Eykykuk

を求めたい。
二乗誤差関数は

12k(yktk)2

なので、

Eyk=yk(12k(yktk)2)=yk(12(y0t0)2+12(y1t1)2+...12(yntn)2)=yktk()

偏微分につき、yk を含まない項はすべて 0 となるため。

隣の項は 活性関数が恒等関数(入力 = 出力)の関数なので、

ykuk=1

なので

δk=(yktk)1=yktk

中間層がシグモイド関数であるとした場合

シグモイド関数 f(x)=11+exp(x)微分f(x)=(1f(x))f(x) なので

ykuj=(1yj)yjδj=yjykuj=yj(1yj)yj

交差エントロピー誤差とソフトマックス関数を採用した出力層

これは分類で使う組み合わせ。

交差エントロピー誤差

こちらの記事 が丁寧に書いてくれています。

数式的にはこんな関数

E(t,y)=xtxlog(yx)

理想の出力 t は、スイッチ出力なので [1, 0, 0, 0] の様な一つだけ 1 となってるので、仮に入力値を [0.5, 0.8, 0.2, 0.4] だとすると

E(t,y)=(1log(0.5)+0log(0.8)+0log(0.2)+0log(0.4)=log(0.5)=(0.693...)=0.693

こんな感じで出てくる。

ソフトマックス関数

過去記事 参照(グラフ付き)。

y=exp(x)nk=1exp(k)

こんな感じの関数を組み合わせすると

E=ktklog(exp(uk)kexp(uk))

この時、 logpq=log(p)log(q) を使うと

=k(tklog(exp(uk))tklog(kexp(uk)))==ktkuk+log(kexp(uk))

これを更新式に突っ込むと

δk=Eykykuk=uk(ktkuk+log(kexp(uk)))=tk+exp(uk)kexp(uk)=tk+yk

ようやっと…微分オワタ (;'∀')