Rについて
をベースに学んでいく。
今回は集団学習(PP.197-205)を扱う。
16. 集団学習
集団学習(アンサンブル学習)は、決して精度が高いわけではない複数の結果を統合・組み合わせ、精度を向上させる機械学習方法である。複数の結果を統合し組み合わせる方法としては、判別・分類問題では多数決、回帰問題では平均が多く用いられている。アンサンブル学習では、異なるウェイトもしくは異なる標本から単純なモデルを複数作成し、これらを何らかの方法で組み合わせることで精度と汎用性を両立するモデルを構築する。
16.1 バギング
バギング(bootstrap aggregating)は1996年にによって提案された。
バギングでは、与えられたデータセットから、ブートストラップ法で複数の学習データを作成し、そのデータを用いて作成した回帰・分類結果を統合・組み合わせることで精度を向上させる。
(1) | 教師データから復元抽出方式で抽出を |
|
(2) | ステップ(1)を |
|
(3-1) | 回帰問題では、 |
|
(3-2) | 判別問題では多数決を取り、 |
################ ### バギング ### ################ library("mlbench") data(BreastCancer) x <- na.omit(BreastCancer) t <- floor(nrow(x) * 0.5) even_n <- 2 * (1:t) BC_train <- BreastCancer[even_n,-1] BC_test <- BreastCancer[-even_n,-1] for(i in 1:9){ BC_train[,i] <- as.integer(BC_train[,i]) BC_test[,i] <- as.integer(BC_test[,i]) } library("adabag") set.seed(20) BC_ba <- bagging(Class~., data = BC_train) BC_bap <- predict(BC_ba,BC_test) (tb_ba <- table(BC_test[,10], BC_bap$class)) 1 - sum(diag(tb_ba)/sum(tb_ba))
16.2 ブースティング
ブースティングは与えた教師付きデータを用いて学習し、その学習結果を踏まえて逐次重みの調整を繰り返すことで複数の学習結果を求める。その結果を統合・組み合わせ、精度を向上させる。
ブースティングの中で最も広く知られているのは、である。こうしたアルゴリズムに共通する大まかな流れは以下の通りである:
(1) | 重みの初期値 |
|
(2) | 学習を行い重みの更新を繰り返す( |
|
(a)重み |
||
(b)誤り率 |
||
(c)結果の信頼度 |
||
(d)重みを更新する。 |
||
(3) | 重み付き多数決で結果を出力する。2値判別の場合は |
###################### ### ブースティング ### ###################### set.seed(20) BC_ad <- boosting(Class~., data = BC_train) BC_adp <- predict(BC_ad, newdata = BC_test) (res <- BC_adp$confusion) 1 - sum(diag(res))/sum(res)
16.3 ランダムフォレスト
ランダムフォレストは。バギングの提案者であるにより提案された新しいデータ解析の手法である。バギングではすべての変数を用いる一方で、ランダムフォレストでは変数をランダムサンプリングしてから用いるため、多次元データ解析に向いている。ランダムサンプリングする変数の数
を自由に選択でき、
は全変数の数の正の平方根を取ることを勧めている。
(1) | 与えられたデータセットから |
|
(2) | 各々のブートストラップ標本データを用いてみ剪定の最大の決定木・回帰木を生成する。ただし分岐のノードはランダムサンプリングされた変数の中から最善のものを用いる。 | |
(3) | すべての結果を統合・組み合わせ、新しい予測・分類器を構築する。 |
########################## ### ランダムフォレスト ### ########################## library("randomForest") set.seed(20) BC_rf <- randomForest(Class~., data = BC_train, na.action = "na.omit") print(BC_rf) summary(BC_rf) BC_rf$type plot(BC_rf) # 木の数と誤判別率の関係 varImpPlot(BC_rf) # ジニ分散指標 BC_rfp <- predict(BC_rf, BC_test[,-10]) (BC_rft <- table(BC_test[,10], BC_rfp)) 1 - sum(diag(BC_rft))/sum(BC_rft)