技術をかじる猫

適当に気になった技術や言語、思ったこと考えた事など。

JavaScript でニューラルネットを実行

前回のニューラルネットの分類で、アヤメの判別を実装しました。

white-azalea.hatenablog.jp

この学習データを JavaScript に持って行って動かそうとしたのが今回。

学習データを取り出そう

ニューロンの学習した重みとバイアスを JSON 化して取り出すようメソッド追加し、

import json

class Neuron:
    # 中略
    
    def show(self):
        print(json.dumps(self.w.tolist()))
        print(json.dumps(self.b.tolist()))

学習完了後にその値を取得します。

middle.show()
output.show()
[[-0.1645569757531888, 0.705461787072002, 0.13634643058513146, -0.001427639477009482, -0.37761246456078673, 0.32594566195486246, -0.17540383608743443, -0.637910226879948], [1.2405711145126166, 2.619979392767813, -0.8480757978236683, -0.34862875404728705, 0.6081766688709285, 2.1609719770773466, 1.1812526170586088, -0.10568548347423115], [-4.74027409643627, -8.512471290631712, 3.0535213927424927, 0.6853620335528933, -3.021058907040047, -7.302407221156781, -4.505304631691936, -1.458038332845281], [-2.6381391656274893, -5.0742956034421365, 1.6448946680125553, 0.321396808734011, -1.629813019140954, -4.235682116890767, -2.519697996486675, -0.7485783120615989]]
[1.3354703014752847, 2.856479211057539, -0.36041020389500916, -0.09619435431305065, 0.5270124844477969, 2.3500865314253825, 1.225318793483202, -0.5722749346649912]
[[3.3900843132512795, 0.3336196951711783, -3.707400474181074], [6.873749691931951, 1.4723979662987134, -8.337604254104868], [-4.032894687445366, 0.20291576621871618, 3.82277459252311], [-1.4497941475896428, 0.3517295106421541, 1.1299533286177883], [1.920858740500269, 0.21229798170387007, -2.151150281026468], [5.754690191389036, 0.9299055788952119, -6.690439208956392], [3.1882961310528275, 0.2901387295306705, -3.4911286154994707], [0.47133299264734746, 0.3382463002700566, -0.8153830951792407]]
[-6.786999325197024, 0.8398018579145571, 5.919115463040789]

JavaScript に適用する

そして利用したものは math.js

techc.omorita.com mathjs.org

これを読み込んで適用したコードが以下です

数学ライブラリですが大分癖が強い感じでした…
学習は Python でやってしまうので、ここでは順伝播 (forward) のみ実装しています。

<!DOCTYPE html>
<html lang="ja">
<head>
    <meta charset="UTF-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Iris</title>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/mathjs/9.2.0/math.js"></script>
</head>
<body>
    <script>
        const middle_w = math.matrix([[-0.37791011431611676, 0.03295488063509766, 0.3420945556236132, 0.30166856875056935, -0.21219877180909835, -0.010069964625783545, 0.27794499885092205, -0.398933181760742], [-1.9255089332415773, 1.2511527783210896, 1.7675887476888166, 1.6755548778237705, 0.7481479434697025, 1.1523410854237444, 1.5831249176747109, 0.4108994310001761], [6.086371767680688, -4.608346851803051, -6.010497555156929, -5.791896330967219, -3.32424587195477, -4.2963446303672965, -5.614285563371247, -2.509715984031604], [3.5367161799907016, -2.633132104633563, -3.505722035617003, -3.368791012895072, -1.8457100201853713, -2.4362656179086803, -3.2355526244884363, -1.351205989346281]]);
        const middle_b = math.matrix([[-1.5688415235721709, 1.3152896281324962, 1.9259653823017484, 1.833634432303389, 0.6867788300844402, 1.154169700725085, 1.7512061518314432, 0.16790344701525367]]);
        const output_w = math.matrix([[-7.0419433950389445, -0.2047617985662437, 7.226118617609221], [3.03586686340172, 0.6930753248897973, -3.7435620670805116], [4.160262110192553, 1.0942939867640822, -5.242799641786152], [3.9678402497242136, 1.0338072055208727, -5.000119717617293], [2.04249348347599, 0.46874914434725296, -2.503804422725612], [2.808673616002259, 0.6218796513771182, -3.4154934374365182], [3.8124202808756524, 0.9763499728790614, -4.790295541993576], [1.4054709195492832, 0.38402111025982255, -1.7693249864093947]]);
        const output_b = math.matrix([[-5.538366535062681, 0.9178471888963776, 4.588826644541712]]);

        function sigmoid(u) {
            let minus = math.map(u, v => -v);
            let tmp = math.exp(minus);
            return math.map(tmp, v => 1 / (1 + v));
        }

        function softmax(u) {
            let expU = math.exp(u);
            let tmp = math.sum(math.exp(u));
            return math.map(expU, v => v / tmp);
        }

        class Neuron {
            constructor(w, b, activation_function) {
                this.w = w;
                this.b = b;
                this.activation_function = activation_function;
            }

            forward(x) {
                let dot = math.multiply(x, this.w);
                let u = math.add(dot, this.b);
                return this.activation_function(u);
            }
        }

        const middle = new Neuron(middle_w, middle_b, sigmoid);
        const output = new Neuron(output_w, output_b, softmax);

        function detect(result) {
            let flatted = math.flatten(result);
            let a = math.subset(flatted, math.index(0));
            let b = math.subset(flatted, math.index(1));
            let c = math.subset(flatted, math.index(2));
            if (a > b && a > c) {
                return 'setosa';
            } else if (b > a && b > c) {
                return 'versicolor';
            } else {
                return 'virginica';
            }
        }

        function exec() {
            let sepal_length = Number(document.getElementById('sepal_length').value);
            let sepal_width = Number(document.getElementById('sepal_width').value);
            let petal_length = Number(document.getElementById('petal_length').value);
            let petal_width = Number(document.getElementById('petal_width').value);

            // ニューラルネットを実行
            const respond = output.forward(middle.forward(math.matrix(
                [[sepal_length / 10.0, sepal_width / 10.0, petal_length / 10.0, petal_width / 10.0]]
            )));
            document.getElementById('newralnet_response').innerText = respond;
            document.getElementById('result').innerText = detect(respond);
        }
    </script>

    <form action="#">
        <div><label for="sepal_length">がくの長さ</label><input id="sepal_length" type="text" name="sepal_length"></div>
        <div><label for="sepal_width">がくの幅</label>  <input id="sepal_width" type="text" name="sepal_width"></div>
        <div><label for="petal_length">花弁の長さ</label><input id="petal_length" type="text" name="petal_lengh"></div>
        <div><label for="petal_width">花弁の幅</label>  <input id="petal_width" type="text" name="petal_width"></div>
        <div id="newralnet_response"></div>
        <div id="result"></div>
        <div><button id="execute" type="button" onclick="exec();">検査実行</button></div>
    </form>
</body>
</html>

実行結果はこんな感じになります。

f:id:white-azalea:20210310232211p:plain

データはアヤメデータセットの値を設定しました。

iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['target'] = iris.target
df.loc[df['target'] == 0, 'target'] = "setosa"
df.loc[df['target'] == 1, 'target'] = "versicolor"
df.loc[df['target'] == 2, 'target'] = "virginica"
df

f:id:white-azalea:20210310232258p:plain