読者です 読者をやめる 読者になる 読者になる

謎言語使いの徒然

適当に気になった技術や言語を流すブログ。

ランダムフォレストを実装する

アルゴリズム 勉強 勉強会 Scala

ランダムフォレスト?

実装してみた感じ、教師ありデータから分類・学習し、その後に与えられた未知のデータに対して、識別・分類を行うアルゴリズム

正答率は 7-8 割位が目安。余り複雑な学習はできない。他のアルゴリズムと組み合わせて使うといい感じになるらしい。

表面的な特徴で言えば、正答率は7-8割程度ではあるが、比較的処理が軽いのと、たまにある例外的なデータノイズに強いところだ。

アルゴリズム

数式やらなにやらは以下のスライドに任せる。

機会学習ハッカソン:ランダムフォレスト

複数の学習器にランダム抽出のデータを食わせて学習させ、データを与えられた際にはその学習器の投票によってその結果を識別する。

そのため、過半数未満の学習器が、学習データのデータノイズで妙な学習をしてしまっても、ある程度補正が効くというのが利点。

勉強会の雰囲気

勉強会で実装出来たのは、動物の識別だけだった。動物の識別パターンは全て二値だったのでやりやすかったとも言える。

2時間で理解して組むって結構鬼仕様だったとだけ言っておく。実装出来た人の方が少ない雰囲気だった。

例に酔って Scala 実装だが、こんな感じだ。

実装(動物の識別)

package animals

import scala.io.Source

case class Line(field: List[String], result: String)

object Resource {
  def read() = {
    val file = Source.fromFile("/path/to/animals.dat", "utf-8")
    file.getLines().filterNot(_.isEmpty).map(s => {
      val splits = s.split("\t")
      Line(
        List(
          splits(0),
          splits(1),
          splits(2)
        ),
        splits(3)
      )
    }).toList
  }

  /**
   * 2/3 をランダムに返す
   * @param data
   * @return
   */
  def dataSets(data: List[Line]):List[Line] = {
    import scala.util.Random
    Random.shuffle(data).take(3 /* (data.size / 3) * 2 */) // データセットが今回小さすぎるので
  }

  /**
   * 初期エントロピー計算用
   * @param data
   */
  def defaultEnt(data: List[Line]): Double = {
    val allSize = data.size
    val groups  = data.groupBy(_.result)

    def groupEnt(set: Int) = {
      val aq = set.toDouble / allSize.toDouble
      aq * Math.log(aq)
    }

    groups.map { case (n, l) => {
      - groupEnt(l.size)
    } } sum
  }
}

/**
 * 親エントロピーと現在のセットを食わせる
 */
case class Node(parent: Double, fields:List[Int], dataSet: List[Line]) {

  val currentSize = dataSet.size

  /**
   * 最大のエントロピーを持つ質問IDとその時のエントロピー
   */
  val maxEntField = {
    // 指定フィールドと正解を与えて最も分類出来ているであろう数値を返す。
    val allEnt = fields.map(i => {
      // このインデックスで分類した時のエントロピーは?
      // キーで2セット作る
      val group = dataSet.groupBy(_.field(i)).map(_._2.size).toSeq
      (i, {
        parent + {
          group.map(_.toDouble / currentSize /* P */).map(p => p * Math.log(p) /* I */).sum
        }
      })
    })
    // 比較した内容をエントロピーでソートして最後のやつが max
    allEnt.sortBy(_._2).last
  }

  val question = maxEntField._1
  val currentEnt = maxEntField._2

  val childNodes:Option[List[Node]] = {
    val usableNext = fields.filterNot(_ == question)
    val group = dataSet.groupBy(_.field(question)).map(_._2).toList.filterNot(_.isEmpty)
    if (usableNext.isEmpty) None else {
      Some(
        group.map(l => Node(currentEnt, usableNext, l))
      )
    }
  }

  /**
   * この解析ツリーにデータを投げて分類を頼む
   * @param check
   */
  def whats(check: List[String]):String = {
    // リーフなら結果を返す
    if (fields.isEmpty || childNodes.isEmpty) dataSet.head.result
    else {
      // ノードなら自分の判定基準に合わせて子のリーフに投げる
      val grp = check(question)
      def testHead = childNodes.get.head.whats(check)
      if (grp == childNodes.get.head.dataSet.head.field(question)) testHead
      else {
        if (childNodes.get.size == 1) testHead
        else childNodes.get.apply(1).whats(check)
      }
    }
  }

  /**
   * 表示してみようか
   */
  override def toString() = s"Node($parent, $dataSet, $question, $currentEnt, $childNodes)"
}

object RandomForest extends App {
  // データソース
  val all = Resource.read()

  // 初期エントロピー
  val I = Resource.defaultEnt(all)

  // 質問に使えるインデックス
  val indexes = 0 to (all.head.field.size - 1)

  // 解析器(適当に 15 個でも作る?)
  val roots = (1 to 15).map(i => Node(I, indexes.toList, Resource.dataSets(all)))

  // 適当にほ乳類でも投げてみましょうか?
  val test = List("草食", "胎生", "恒温")
  val results = roots.map(n => n.whats(test))

  println(s"Animal: $test : $results")
}

読みづらい。

リファクタなんぞしている暇がなかったのだから仕方無い。

アヤメの種類識別

データ

テストデータはアヤメの葉っぱのサイズ。 これは 第三回機械学習アルゴリズム実装会 - connpassで紹介されたデータだ。

https://github.com/watanabetanaka/randomForest/

アヤメの種類は、葉や茎、花のサイズで識別出来るという理屈らしい。

package iris

import scala.io.Source

case class Line(id: Long, field: List[Double], result: String)

object Resource {

  /**
   * ファイルを開く
   * @return
   */
  def read() = {
    val resource = this.getClass.getResource("../iris.dat")
    val file = Source.fromFile(resource.getPath, "utf-8")

    file.getLines().filterNot(_.isEmpty).map(s => {
      val splits = s.split("\t")
      Line(
        splits(0).toInt,
        List(
          splits(1).toDouble,
          splits(2).toDouble,
          splits(3).toDouble,
          splits(4).toDouble
        ),
        splits(5)
      )
    }).toList
  }
}

object RandomForest {

  /**
   * 2/3 をランダムに返す
   * @param data
   * @return
   */
  def dataSets(data: List[Line]):List[Line] = {
    import scala.util.Random
    Random.shuffle(data).take((data.size / 3) * 2)
  }

  /**
   * 初期エントロピー計算用
   * @param data
   */
  def defaultEnt(data: List[Line]): Double = {
    val allSize = data.size
    val groups  = data.groupBy(_.result)

    def groupEnt(set: Int) = {
      val aq = set.toDouble / allSize.toDouble
      aq * Math.log(aq)
    }

    groups.map { case (n, l) => {
      - groupEnt(l.size)
    } } sum
  }

  trait Node {
    val result: String
    val isLeaf: Boolean
    def dispTree(s: String): Unit
    def test(dataSet: List[Double]): String = result
  }

  case class Branch(index:Int, entropy:Double, threshHold:Double, left: Option[Node], right:Option[Node], result: String) extends Node {
    val isLeaf:Boolean = false

    def dispTree(s: String) {
      println(s"$s ${this.toString}")
      val next = s"$s  "
      left.foreach(_.dispTree(next))
      right.foreach(_.dispTree(next))
    }

    override def test(dataSet: List[Double]) = {
      val border = dataSet(index) < threshHold
      if (border) right.map(_.test(dataSet)).getOrElse(left.map(_.test(dataSet)).getOrElse("ERROR"))
      else left.map(_.test(dataSet)).getOrElse("ERROR")
    }
  }

  case class Leaf(result: String) extends Node {
    val isLeaf:Boolean = true

    def dispTree(s: String) {
      println(s"$s ${this.toString}")
    }
  }

  def createNode(i: Double, indexes: List[Int], dataSets: List[Line]):Node = {
    /**
     * 終了条件:末端まで到達した
     * - 比較出来る質問がなくなった
     * - 全データが同じ物だと断言出来る
     * @return
     */
    def isLeaf() = indexes.isEmpty || dataSets.groupBy(_.result).size == 1

    /**
     * 現在のデータサイズ
     * @return
     */
    def dataSize() = dataSets.size

    /**
     * 現在のノードの多数派結論
     * @return
     */
    def result() = dataSets.map(s => s.result -> s.result).groupBy(_._1).map(s => s._1 -> s._2.size).toList.sortBy(_._2).last._1

    /**
     * P(x) の計算
     * @param size
     * @return
     */
    def P(size: Double) = size / dataSize.toDouble

    /**
     * 個々のエントロピー計算
     * @param size
     * @return
     */
    def I(size: Double) = P(size) * Math.log(P(size))

    // ここが終端なら、さっさと結果だけ返す。
    if (isLeaf) {
      Leaf(result())
    } else {
      // ここは枝なので、小要素を作る
      // どのインデックスで分離するのが一番エントロピーが大きい?
      val ents = indexes.map(index => {
        // オーダー
        val pair = dataSets.map(s => s.field(index) -> s)

        // しきい値
        val threshHold = pair.map(_._1).sum / dataSize.toDouble

        // しきい値でデータを分割
        val data  = dataSets.map(l => (l.field(index) < threshHold) -> l)
        val left  = data.filter(_._1 == false).map(_._2)
        val right = data.filter(_._1).map(_._2)

        // エントロピーの計算
        val entropy = i + (I(right.size) + I(left.size))

        if (left.isEmpty || right.isEmpty) {
          println("====================================")
          println("WARNING: Left or Right is Empty")
          println(s"DataSet: $dataSets")
          println(s"Index: $index, ThreshHold: $threshHold")
          println(s"Left: ${left.isEmpty}, ${right.isEmpty}")
          println("====================================")
        }

        (index, threshHold, entropy, left, right)
      }).sortBy(_._3).last

      val left  = ents._4
      val right = ents._5

      def makeNode(subset: List[Line]):Option[Node] = {
        val subIndexes = indexes.filterNot(_ == ents._1)
        if (subset.isEmpty) None
        else Some(createNode(ents._3, subIndexes, subset))
      }

      Branch(ents._1, ents._3, ents._2, makeNode(left), makeNode(right), result())
    }
  }
}

object IrisForest extends App {
  // データソース
  val all = Resource.read()

  // 初期エントロピー
  val I = RandomForest.defaultEnt(all)
  println(s"初期エントロピー : $I")

  // 質問に使えるインデックス
  val indexes = (0 to (all.head.field.size - 1)).toList

  // 学習器の作成
  val mls = (1 to 20).map(_ => RandomForest.createNode(I, indexes, RandomForest.dataSets(all))).toList

  mls.foreach(_.dispTree(""))

  // 動かしてみる
  // 87    6.7 3.1 4.7 1.5 versicolor
  val data = List(6.7, 3.1, 4.7, 1.5)
  println(s"TEST: 87, 6.7, 3.1, 4.7, 1.5 versicolor, Results: ${mls.map(_.test(data))}")
}