Repeated Cross-Validation(RCV)はCVを反復して安定性を高めようという方法である。
ただ、何回繰り返せばよいかわからない。
説明を読んでいると5回から10回くらいであろうと書いてあるが、その根拠は経験則のようだ。 社会調査で使うデータの反復計算は短時間で可能だが、データビッグデータでLassoを行った場合、計算時間と計算の安定性はトレードオフの関係になり、バランスを取ることが求められた結果、5回から10回という設定が落としどころなのだ。
実際にどのくらいの回数が適切なのかを調べるために、シミュレーションをしてみた。
RCVの反復回数は1回から30回までを実験することにした。回数が多くなれば安定性が高まるといわれているが、本当にそうなのかを検証する。
Lassoは乱数の影響を受けやすいため、乱数は5つの異なるシードを用意し、各回数ごとに5回推定を行い、統計量はその平均を取った。
データの読み込みと前処理
library(haven) library(tidyr) auto <- haven::read_dta("http://www.stata-press.com/data/r9/auto.dta") auto <- auto %>% drop_na() # 全カラムに対してNAがない行を抽出 auto$foreign <- as.integer(auto$foreign) independent.var <- auto[, c("price", "mpg", "rep78", "headroom", "trunk", "weight", "length", "turn", "displacement", "gear_ratio", "foreign")]
doParallelパッケージを用いて1コアではなく複数コアで計算するようにして、計算時間を短縮させるようにした。
library(doParallel) # 利用可能なコア数を検出 システム全体のコア数から2つを引く num_cores <- detectCores() - 2 # クラスタを作成し、コア数を指定 cl <- makeCluster(num_cores) # 並列バックエンドを登録 registerDoParallel(cl)
1回から30回の反復計算
library(caret) library(glmnet) results <- list() for (repeats in 1:30) { predictions <- list() for (seed in 1:5) { set.seed(seed) train_control <- trainControl(method = "repeatedcv", number = 10, # 10-fold CV repeats = repeats, # 繰り返しの回数 savePredictions = "final") model <- train(price ~ ., data = independent.var, method = "glmnet", # Lasso回帰 trControl = train_control, tuneLength = 10) # チューニングするλの数 predictions[[seed]] <- model$pred } # すべての予測値を結合 all_preds <- do.call(rbind, predictions) results[[paste("repeats", repeats, sep = "_")]] <- all_preds }
統計量の計算
calculate_metrics <- function(predictions) { actual <- as.numeric(predictions$obs) predicted <- as.numeric(predictions$pred) mse <- mean((actual - predicted)^2) rmse <- sqrt(mse) mae <- mean(abs(actual - predicted)) rsquared <- cor(actual, predicted)^2 # 95%信頼区間の計算 mse_se <- sd((actual - predicted)^2) / sqrt(length(actual)) mse_ci_lower <- mse - 1.96 * mse_se mse_ci_upper <- mse + 1.96 * mse_se rmse_se <- sd(sqrt((actual - predicted)^2)) / sqrt(length(actual)) rmse_ci_lower <- rmse - 1.96 * rmse_se rmse_ci_upper <- rmse + 1.96 * rmse_se mae_se <- sd(abs(actual - predicted)) / sqrt(length(actual)) mae_ci_lower <- mae - 1.96 * mae_se mae_ci_upper <- mae + 1.96 * mae_se list(MSE = mse, RMSE = rmse, MAE = mae, R2 = rsquared, MSE_CI_Lower = mse_ci_lower, MSE_CI_Upper = mse_ci_upper, RMSE_CI_Lower = rmse_ci_lower, RMSE_CI_Upper = rmse_ci_upper, MAE_CI_Lower = mae_ci_lower, MAE_CI_Upper = mae_ci_upper) } metrics <- data.frame(repeats = integer(), MSE = numeric(), RMSE = numeric(), MAE = numeric(), R2 = numeric(), MSE_CI_Lower = numeric(), MSE_CI_Upper = numeric(), RMSE_CI_Lower = numeric(), RMSE_CI_Upper = numeric(), MAE_CI_Lower = numeric(), MAE_CI_Upper = numeric()) for (repeats in 1:30) { preds <- results[[paste("repeats", repeats, sep = "_")]] metrics_res <- calculate_metrics(preds) metrics <- rbind(metrics, data.frame(repeats = repeats, MSE = metrics_res$MSE, RMSE = metrics_res$RMSE, MAE = metrics_res$MAE, R2 = metrics_res$R2, MSE_CI_Lower = metrics_res$MSE_CI_Lower, MSE_CI_Upper = metrics_res$MSE_CI_Upper, RMSE_CI_Lower = metrics_res$RMSE_CI_Lower, RMSE_CI_Upper = metrics_res$RMSE_CI_Upper, MAE_CI_Lower = metrics_res$MAE_CI_Lower, MAE_CI_Upper = metrics_res$MAE_CI_Upper)) } # 並行処理の停止 stopCluster(cl)
結果の表示
print(metrics) write.csv(metrics, "metrics.csv")
repeats | MSE | RMSE | MAE | R2 |
---|---|---|---|---|
1 | 4640526 | 2154.188 | 1622.182 | 0.44507 |
2 | 4657895 | 2158.216 | 1621.973 | 0.442928 |
3 | 4691725 | 2166.039 | 1613.545 | 0.439773 |
4 | 4667408 | 2160.418 | 1614.21 | 0.442108 |
5 | 4653046 | 2157.092 | 1617.974 | 0.44343 |
6 | 4664083 | 2159.649 | 1621.68 | 0.442089 |
7 | 4671125 | 2161.278 | 1622.283 | 0.441242 |
8 | 4685531 | 2164.609 | 1625.766 | 0.439507 |
9 | 4679912 | 2163.311 | 1624.811 | 0.440183 |
10 | 4681543 | 2163.687 | 1625.365 | 0.439969 |
11 | 4676046 | 2162.417 | 1623.825 | 0.440633 |
12 | 4671917 | 2161.462 | 1622.463 | 0.441123 |
13 | 4663648 | 2159.548 | 1620.973 | 0.442115 |
14 | 4661294 | 2159.003 | 1621.043 | 0.442395 |
15 | 4643063 | 2154.777 | 1622.585 | 0.444758 |
16 | 4639505 | 2153.951 | 1621.533 | 0.445182 |
17 | 4638211 | 2153.651 | 1622.331 | 0.445373 |
18 | 4648235 | 2155.976 | 1618.812 | 0.443981 |
19 | 4646286 | 2155.524 | 1618.784 | 0.444212 |
20 | 4650611 | 2156.528 | 1617.275 | 0.443693 |
21 | 4654001 | 2157.313 | 1623.167 | 0.443283 |
22 | 4656675 | 2157.933 | 1620.82 | 0.442949 |
23 | 4651028 | 2156.624 | 1620.096 | 0.443641 |
24 | 4658572 | 2158.372 | 1620.891 | 0.442723 |
25 | 4660592 | 2158.84 | 1622.141 | 0.442491 |
26 | 4670758 | 2161.194 | 1621.974 | 0.441388 |
27 | 4662009 | 2159.169 | 1622.854 | 0.442308 |
28 | 4662111 | 2159.192 | 1620.957 | 0.442302 |
29 | 4657345 | 2158.088 | 1619.483 | 0.44291 |
30 | 4663181 | 2159.44 | 1619.394 | 0.442219 |
MSEのプロット
MSEとMSEの95%信頼区間のプロットをみてみよう。
library(ggplot2) ggplot(metrics, aes(x = repeats, y = MSE)) + geom_line(color = "blue") + geom_point(color = "blue") + geom_ribbon(aes(ymin = MSE_CI_Lower, ymax = MSE_CI_Upper), alpha = 0.2, fill = "gray") + ylab("MSE") + xlab("Number of Repeats") + ggtitle("MSE over Repeats with 95% CI") + theme_minimal()
RMSEのプロット
次にRMSEとRMSEの95%信頼区間のプロットをみてみよう。
ggplot(metrics, aes(x = repeats, y = RMSE)) + geom_line(color = "red") + geom_point(color = "red") + geom_ribbon(aes(ymin = RMSE_CI_Lower, ymax = RMSE_CI_Upper), alpha = 0.2, fill = "gray") + ylab("RMSE") + xlab("Number of Repeats") + ggtitle("RMSE over Repeats with 95% CI") + theme_minimal()
MAEのプロット
MAEとMAEの95%信頼区間のプロットをみてみよう。
ggplot(metrics, aes(x = repeats, y = MAE)) + geom_line(color = "green") + geom_point(color = "green") + geom_ribbon(aes(ymin = MAE_CI_Lower, ymax = MAE_CI_Upper), alpha = 0.2, fill = "gray") + ylab("MAE") + xlab("Number of Repeats") + ggtitle("MAE over Repeats with 95% CI") + theme_minimal()
R-squaredの結果をまとめたグラフ
一応、R二乗値もみよう。
# R-squaredのプロット ggplot(metrics, aes(x = repeats, y = R2)) + geom_line(color = "purple") + geom_point(color = "purple") + ylab("R-squared") + xlab("Number of Repeats") + ggtitle("R-squared over Repeats") + theme_minimal()
R二乗値はあまりどうでもいいので無視をするべきだろう。
小括
MSE、RMSE、MAEは反復回数が多くなるにしたがって95%区間が狭まるため、やはり反復回数を多くすれば安定するというのは正しいようだ。
各統計量をみると、MSE、RMSEが最も少なかったのは17回目で、それぞれ4638211.3、2153.65であった。17回の次は、16回、1回、15回、19回、18回となった。されほど大差はないようにも見えるが、反復回数は15-19のあいだに集中している。これはこのデータでに関しては、ということではあるが。
MAEは一桁を無視すると、20、19、18、30、29、23回という順番でMSE、RMSEと同じ傾向とはいかないようだ。
統計学的な根拠をもって、回数を決めるというのはできないようだ。
安定というのは、統計量の95%信頼区間が狭いことでもあるので、95%信頼区間から変動点の検出を行ってみる。
RMSEの変動点の検出
# 分析用のデータを作成 RMSE.dat <- subset(metrics, select = c(repeats, RMSE, RMSE_CI_Lower, RMSE_CI_Upper)) # 必要なライブラリをロード library(tidyverse) library(changepoint) # 95%信頼区間の幅を計算 RMSE.dat$CI_width <- RMSE.dat$RMSE_CI_Upper - RMSE.dat$RMSE_CI_Lower # 変動点検出 cpt <- cpt.mean(RMSE.dat$CI_width, method = "PELT") # 結果のプロット ggplot(RMSE.dat, aes(x = repeats, y = CI_width)) + geom_line() + geom_point() + geom_vline(xintercept = cpts(cpt), linetype = "dashed", color = "red") + labs(title = "95%信頼区間幅の変動点検出", x = "Repeats", y = "CI Width") + theme_minimal() # 検出された変動点の表示 cpts(cpt)
結果
1 2 3 4 5 6 7 8 9 10 12 14 16 19 22 26
RMSEのそのまま値だと変動点を見出すのは困難なようだ。
RMSEの平均と分散の変動点を検出
cpt_meanvar <- cpt.meanvar(RMSE.dat$CI_width, method = "PELT") # 結果のプロット ggplot(RMSE.dat, aes(x = repeats, y = CI_width)) + geom_line() + geom_point() + geom_vline(xintercept = cpts(cpt_meanvar), linetype = "dashed", color = "red") + labs(title = "95%信頼区間幅の変動点検出 (平均と分散)", x = "Repeats", y = "CI Width") + theme_minimal() # 検出された変動点の表示 cpts(cpt_meanvar)
結果。
6 14 22
MSEの平均と分散の変動点を検出
# 分析用のデータを作成 MSE.dat <- subset(metrics, select = c(repeats, MSE, MSE_CI_Lower, MSE_CI_Upper)) # 95%信頼区間の幅を計算 MSE.dat$CI_width <- MSE.dat$MSE_CI_Upper - MSE.dat$MSE_CI_Lower # 平均と分散の変動点を検出 cpt_meanvar.MSE <- cpt.meanvar(MSE.dat$CI_width, method = "PELT") # 結果のプロット ggplot(MSE.dat, aes(x = repeats, y = CI_width)) + geom_line() + geom_point() + geom_vline(xintercept = cpts(cpt_meanvar.MSE), linetype = "dashed", color = "red") + labs(title = "95%信頼区間幅の変動点検出 (平均と分散)", x = "Repeats", y = "CI Width") + theme_minimal() # 検出された変動点の表示 cpts(cpt_meanvar.MSE)
結果。
6 14 19 24 26
MAEの平均と分散の変動点を検出
# 分析用のデータを作成 MAE.dat <- subset(metrics, select = c(repeats, MAE, MAE_CI_Lower, MAE_CI_Upper)) # 95%信頼区間の幅を計算 MAE.dat$CI_width <- MAE.dat$MAE_CI_Upper - MAE.dat$MAE_CI_Lower # 平均と分散の変動点を検出 cpt_meanvar.MAE <- cpt.meanvar(MAE.dat$CI_width, method = "PELT") # 結果のプロット ggplot(MAE.dat, aes(x = repeats, y = CI_width)) + geom_line() + geom_point() + geom_vline(xintercept = cpts(cpt_meanvar.MAE), linetype = "dashed", color = "red") + labs(title = "95%信頼区間幅の変動点検出 (平均と分散)", x = "Repeats", y = "CI Width") + theme_minimal() # 検出された変動点の表示 cpts(cpt_meanvar.MAE)
結果。
6 14 22
結論
計算して気づいたがRMSEとMAEの95%信頼区間幅同じになるようだ。そのことを考慮に入れて、RMSEとMSEとMAEを見比べると6回と14回がいずれも変動点と検出されているので、おそらく6回か14回というのがこの分析が進める回数といえるだろう。
6回というのは5から10回という経験則を示したものであり、データによって違うことを考慮すれば、10回に近い方が安全であろう。
今回のようなサンプルサイズが小さなデータであれば、14回が適切な回数となるだろう。
この回数はこのデータで計算した時の数字である。どのデータにも当てはまるものではない。
とはいえ、データごとに変動点は算出できるので、データごとに判断することは可能である。