最適化アルゴリズム
と言っても機械学習的な意味で。
単純に勾配降下法を適用すると、局所解に捕まる問題は先に述べた通り。
分かりやすくサンプルを考えてみると
import numpy as np import matplotlib.pyplot as plt def func(x): return x*x*x*x + 2*x*x*x + -38*x*x + 2*x X = np.arange(-10, 10, 0.02) Y = func(X) plt.plot(X, Y) plt.show()
微分済みの関数は
def delta_func(x): return 4*x*x*x + 6*x*x -76*x + 2 # ここに 7,5 を突っ込むと delta_func(7.5) # 1457.0
傾斜凄いので、学習バイアスを仮に 0.001 位で想定すると
6.043 5.398451459972 5.002559811688421 4.729832392627848 4.529821496454675 4.377177480355534 4.257423546210319 4.161559438372887 4.08363734432107 4.019540738521101 3.9663154081729166 3.921778562920445 3.8842783889882955 3.8525397914062047 3.8255618877089477 中略 3.658931080932697 3.658743374723007 3.6585797983577586 3.658437247050682 3.658313016197765 3.658204749631204 3.6581103946042393 3.6581103946042393
とまぁこんな感じで、明らかに x=-5....
の方が正解にもかかわらず途中の 3.7
近辺に捕まってしまう。
この局所解を回避するのが最適化アルゴリズムと言ってるご様子。
確率的勾配降下法
訓練用データの中からパラメータの更新毎に、ランダムなサンプルを選び出すことで、局所解に捕まりにくくする。
要するに開始位置をランダムに指定するだけだ。
運よく 0 以下で始まれば最も値の小さい箇所に行ける。
Momentum
確率的勾配降下法に、慣性項を設けたというもの。
前回の更新量に追加でいくらかの値を設定する。
alpha = 0.001 beta = 2.7 current = 7.5 old = 0 last_w = 0 while abs(current - old) >= 0.0001: val = delta_func(current) old = current new_w = - val * alpha current = current + new_w + last_w * beta last_w = new_w print(current)
6.043 1.4645514599719993 -0.19185848304183395 0.01781967384422567 -0.028117364030085537 -0.034007546495780354 -0.049781152307481306 -0.06797555468750062 -0.0908220708732037 -0.11919109008260265 -0.15449030499302088 中略 -5.191354001667161 -5.180269368221342 -5.1771810299574925 -5.180181249424525 -5.184162051676043 -5.185922174066375 -5.185379292975573 -5.184060245776873 -5.183254760318419 -5.183254259207195
見ての通り、下がるときに慣性を利用して追加で下がるので、パラメータさえ合ってれば局所解を乗り越えていく。
AdaGrad
2011 年に現れたアルゴリズムで、学習が進むたびに学習係数を減らそうという試み。
import math h = 0 current = 7.5 old = 0 while abs(current - old) >= 0.0001: val = delta_func(current) h = h + val * val new_w = - val / math.sqrt(h) old = current current = current + new_w print(current) print(current)
見ればわかるが、ガンガン変化量を削られるw
6.5 5.991688924667466 5.646410961134836 5.386994732474357 5.181464168891868 5.013125259836374 4.87204852326966 4.7518067607144445 4.647984515962909 4.557410990822424 4.47772972020331 4.407140798818975 4.344238282940154 4.287903250227089 4.23723099543947 4.1914800066613 4.150035310462555 4.1123815712304745 4.078082977842838 4.04676795652729 4.018117381523742 3.991855364159257 3.9677419716725204 3.945567410183514 3.925147332371601 3.906319018896746 3.8889382456169352 3.8728766941891757 3.8580197969761705 3.8442649318798843 3.8315199012281496 3.8197016428478148 3.808735132162399 3.798552442405527 3.789091936457758 3.780297568841414 3.772118280375192 3.764507471142559 3.757422539948691 3.7508244804687685 3.744677525931404 3.738948835515955 3.733608216734055 3.728627878962639 3.723982214036123 3.719647600419263 3.7156022279933074 3.711825940915271 3.7083000963686423 3.705007437325738 3.7019319776970567 3.6990588984593322 3.6963744535380294 3.6938658843770793 3.6915213422630684 3.689329817586468 3.6872810753218084 3.6853655960944005 3.68357452227539 3.6818996086112543 3.680333176949806 3.6788680746735536 3.677497636493863 3.6762156492967013 3.6750163197634835 3.6738942445193583 3.672844382586653 3.6718620299436235 3.670942796008503 3.6700825818864313 3.6692775602324965 3.668524156598014 3.6678190321395934 3.6671590675816272 3.6665413483327454 3.665963150665686 3.665421928878012 3.664915303358293 3.66444104948883 3.6639970873218672 3.663581471971482 3.663192384668147 3.662828124427272 3.6624871002869765 3.662167824073909 3.661868903659179 3.6615890366694464 3.6613270046208934 3.6610816674463065 3.6608519583877244 3.660636879229203 3.6604354958461323 3.660246934049288 3.660070375703399 3.6599050551014836 3.6597502555775594 3.659605306341589 3.659469579521657 3.6593424873994564 3.6592234798261245 3.6591120418063894 3.6590076912398155 3.6589099768087188 3.6589099768087188
この h が毎回デカくなるから、学習もその都度抑えられるという仕組みだが、弱点として途中で h
がデカくなりすぎて更新が止まることがある点。
今回は見事にそれ。
RMSProp
論文は存在してない?
を仕込むことで、以前の h をある程度忘れるという式ですね。