技術をかじる猫

適当に気になった技術や言語、思ったこと考えた事など。

最適化アルゴリズム

と言っても機械学習的な意味で。

単純に勾配降下法を適用すると、局所解に捕まる問題は先に述べた通り。
分かりやすくサンプルを考えてみると

import numpy as np
import matplotlib.pyplot as plt


def func(x):
    return x*x*x*x + 2*x*x*x + -38*x*x + 2*x


X = np.arange(-10, 10, 0.02)
Y = func(X)

plt.plot(X, Y)
plt.show()

f:id:white-azalea:20210303203211p:plain

微分済みの関数は

def delta_func(x):
    return 4*x*x*x + 6*x*x -76*x + 2


# ここに 7,5 を突っ込むと
delta_func(7.5)  # 1457.0

傾斜凄いので、学習バイアスを仮に 0.001 位で想定すると

6.043
5.398451459972
5.002559811688421
4.729832392627848
4.529821496454675
4.377177480355534
4.257423546210319
4.161559438372887
4.08363734432107
4.019540738521101
3.9663154081729166
3.921778562920445
3.8842783889882955
3.8525397914062047
3.8255618877089477
中略
3.658931080932697
3.658743374723007
3.6585797983577586
3.658437247050682
3.658313016197765
3.658204749631204
3.6581103946042393
3.6581103946042393

とまぁこんな感じで、明らかに x=-5.... の方が正解にもかかわらず途中の 3.7 近辺に捕まってしまう。
この局所解を回避するのが最適化アルゴリズムと言ってるご様子。

確率的勾配降下法

訓練用データの中からパラメータの更新毎に、ランダムなサンプルを選び出すことで、局所解に捕まりにくくする。
要するに開始位置をランダムに指定するだけだ。

運よく 0 以下で始まれば最も値の小さい箇所に行ける。

Momentum

確率的勾配降下法に、慣性項を設けたというもの。
前回の更新量に追加でいくらかの値を設定する。

alpha = 0.001
beta = 2.7

current = 7.5
old = 0
last_w = 0

while abs(current - old) >= 0.0001:
    val = delta_func(current)
    old = current
    new_w = - val * alpha
    current = current + new_w + last_w * beta
    last_w = new_w
    print(current)
6.043
1.4645514599719993
-0.19185848304183395
0.01781967384422567
-0.028117364030085537
-0.034007546495780354
-0.049781152307481306
-0.06797555468750062
-0.0908220708732037
-0.11919109008260265
-0.15449030499302088
中略
-5.191354001667161
-5.180269368221342
-5.1771810299574925
-5.180181249424525
-5.184162051676043
-5.185922174066375
-5.185379292975573
-5.184060245776873
-5.183254760318419
-5.183254259207195

見ての通り、下がるときに慣性を利用して追加で下がるので、パラメータさえ合ってれば局所解を乗り越えていく。

AdaGrad

2011 年に現れたアルゴリズムで、学習が進むたびに学習係数を減らそうという試み。


h \leftarrow  h + (\frac{\partial E}{\partial w})^2 \\
w \leftarrow w - \eta \frac{1}{\sqrt{h}}\frac{\partial E}{\partial w}
import math

h = 0
current = 7.5
old = 0

while abs(current - old) >= 0.0001:
    val = delta_func(current)
    h = h + val * val
    new_w = - val / math.sqrt(h)
    old = current
    current = current + new_w
    print(current)

print(current)

見ればわかるが、ガンガン変化量を削られるw

6.5
5.991688924667466
5.646410961134836
5.386994732474357
5.181464168891868
5.013125259836374
4.87204852326966
4.7518067607144445
4.647984515962909
4.557410990822424
4.47772972020331
4.407140798818975
4.344238282940154
4.287903250227089
4.23723099543947
4.1914800066613
4.150035310462555
4.1123815712304745
4.078082977842838
4.04676795652729
4.018117381523742
3.991855364159257
3.9677419716725204
3.945567410183514
3.925147332371601
3.906319018896746
3.8889382456169352
3.8728766941891757
3.8580197969761705
3.8442649318798843
3.8315199012281496
3.8197016428478148
3.808735132162399
3.798552442405527
3.789091936457758
3.780297568841414
3.772118280375192
3.764507471142559
3.757422539948691
3.7508244804687685
3.744677525931404
3.738948835515955
3.733608216734055
3.728627878962639
3.723982214036123
3.719647600419263
3.7156022279933074
3.711825940915271
3.7083000963686423
3.705007437325738
3.7019319776970567
3.6990588984593322
3.6963744535380294
3.6938658843770793
3.6915213422630684
3.689329817586468
3.6872810753218084
3.6853655960944005
3.68357452227539
3.6818996086112543
3.680333176949806
3.6788680746735536
3.677497636493863
3.6762156492967013
3.6750163197634835
3.6738942445193583
3.672844382586653
3.6718620299436235
3.670942796008503
3.6700825818864313
3.6692775602324965
3.668524156598014
3.6678190321395934
3.6671590675816272
3.6665413483327454
3.665963150665686
3.665421928878012
3.664915303358293
3.66444104948883
3.6639970873218672
3.663581471971482
3.663192384668147
3.662828124427272
3.6624871002869765
3.662167824073909
3.661868903659179
3.6615890366694464
3.6613270046208934
3.6610816674463065
3.6608519583877244
3.660636879229203
3.6604354958461323
3.660246934049288
3.660070375703399
3.6599050551014836
3.6597502555775594
3.659605306341589
3.659469579521657
3.6593424873994564
3.6592234798261245
3.6591120418063894
3.6590076912398155
3.6589099768087188
3.6589099768087188

この h が毎回デカくなるから、学習もその都度抑えられるという仕組みだが、弱点として途中で h がデカくなりすぎて更新が止まることがある点。
今回は見事にそれ。

RMSProp

論文は存在してない?


h \leftarrow \rho h + (1 - \rho) (\frac{\partial E}{\partial w})^2 \\
w \leftarrow w - \eta \frac{1}{\sqrt{h}}\frac{\partial E}{\partial w}

 \rho を仕込むことで、以前の h をある程度忘れるという式ですね。