Repeated Cross-Validation(RCV)はCVを反復して安定性を高めようという方法である。
ただ、何回繰り返せばよいかわからない。
説明を読んでいると5回から10回くらいであろうと書いてあるが、その根拠は経験則のようだ。 社会調査で使うデータの反復計算は短時間で可能だが、データビッグデータでLassoを行った場合、計算時間と計算の安定性はトレードオフの関係になり、バランスを取ることが求められた結果、5回から10回という設定が落としどころなのだ。
実際にどのくらいの回数が適切なのかを調べるために、シミュレーションをしてみた。
RCVの反復回数は1回から30回までを実験することにした。回数が多くなれば安定性が高まるといわれているが、本当にそうなのかを検証する。
Lassoは乱数の影響を受けやすいため、乱数は5つの異なるシードを用意し、各回数ごとに5回推定を行い、統計量はその平均を取った。
データの読み込みと前処理
library(haven)
library(tidyr)
auto <- haven::read_dta("http://www.stata-press.com/data/r9/auto.dta")
auto <- auto %>% drop_na() # 全カラムに対してNAがない行を抽出
auto$foreign <- as.integer(auto$foreign)
independent.var <- auto[, c("price", "mpg", "rep78", "headroom", "trunk", "weight", "length", "turn", "displacement", "gear_ratio", "foreign")]
doParallelパッケージを用いて1コアではなく複数コアで計算するようにして、計算時間を短縮させるようにした。
library(doParallel) # 利用可能なコア数を検出 システム全体のコア数から2つを引く num_cores <- detectCores() - 2 # クラスタを作成し、コア数を指定 cl <- makeCluster(num_cores) # 並列バックエンドを登録 registerDoParallel(cl)
1回から30回の反復計算
library(caret)
library(glmnet)
results <- list()
for (repeats in 1:30) {
predictions <- list()
for (seed in 1:5) {
set.seed(seed)
train_control <- trainControl(method = "repeatedcv",
number = 10, # 10-fold CV
repeats = repeats, # 繰り返しの回数
savePredictions = "final")
model <- train(price ~ ., data = independent.var,
method = "glmnet", # Lasso回帰
trControl = train_control,
tuneLength = 10) # チューニングするλの数
predictions[[seed]] <- model$pred
}
# すべての予測値を結合
all_preds <- do.call(rbind, predictions)
results[[paste("repeats", repeats, sep = "_")]] <- all_preds
}
統計量の計算
calculate_metrics <- function(predictions) {
actual <- as.numeric(predictions$obs)
predicted <- as.numeric(predictions$pred)
mse <- mean((actual - predicted)^2)
rmse <- sqrt(mse)
mae <- mean(abs(actual - predicted))
rsquared <- cor(actual, predicted)^2
# 95%信頼区間の計算
mse_se <- sd((actual - predicted)^2) / sqrt(length(actual))
mse_ci_lower <- mse - 1.96 * mse_se
mse_ci_upper <- mse + 1.96 * mse_se
rmse_se <- sd(sqrt((actual - predicted)^2)) / sqrt(length(actual))
rmse_ci_lower <- rmse - 1.96 * rmse_se
rmse_ci_upper <- rmse + 1.96 * rmse_se
mae_se <- sd(abs(actual - predicted)) / sqrt(length(actual))
mae_ci_lower <- mae - 1.96 * mae_se
mae_ci_upper <- mae + 1.96 * mae_se
list(MSE = mse, RMSE = rmse, MAE = mae, R2 = rsquared,
MSE_CI_Lower = mse_ci_lower, MSE_CI_Upper = mse_ci_upper,
RMSE_CI_Lower = rmse_ci_lower, RMSE_CI_Upper = rmse_ci_upper,
MAE_CI_Lower = mae_ci_lower, MAE_CI_Upper = mae_ci_upper)
}
metrics <- data.frame(repeats = integer(),
MSE = numeric(), RMSE = numeric(), MAE = numeric(), R2 = numeric(),
MSE_CI_Lower = numeric(), MSE_CI_Upper = numeric(),
RMSE_CI_Lower = numeric(), RMSE_CI_Upper = numeric(),
MAE_CI_Lower = numeric(), MAE_CI_Upper = numeric())
for (repeats in 1:30) {
preds <- results[[paste("repeats", repeats, sep = "_")]]
metrics_res <- calculate_metrics(preds)
metrics <- rbind(metrics, data.frame(repeats = repeats,
MSE = metrics_res$MSE,
RMSE = metrics_res$RMSE,
MAE = metrics_res$MAE,
R2 = metrics_res$R2,
MSE_CI_Lower = metrics_res$MSE_CI_Lower,
MSE_CI_Upper = metrics_res$MSE_CI_Upper,
RMSE_CI_Lower = metrics_res$RMSE_CI_Lower,
RMSE_CI_Upper = metrics_res$RMSE_CI_Upper,
MAE_CI_Lower = metrics_res$MAE_CI_Lower,
MAE_CI_Upper = metrics_res$MAE_CI_Upper))
}
# 並行処理の停止
stopCluster(cl)
結果の表示
print(metrics) write.csv(metrics, "metrics.csv")
| repeats | MSE | RMSE | MAE | R2 |
|---|---|---|---|---|
| 1 | 4640526 | 2154.188 | 1622.182 | 0.44507 |
| 2 | 4657895 | 2158.216 | 1621.973 | 0.442928 |
| 3 | 4691725 | 2166.039 | 1613.545 | 0.439773 |
| 4 | 4667408 | 2160.418 | 1614.21 | 0.442108 |
| 5 | 4653046 | 2157.092 | 1617.974 | 0.44343 |
| 6 | 4664083 | 2159.649 | 1621.68 | 0.442089 |
| 7 | 4671125 | 2161.278 | 1622.283 | 0.441242 |
| 8 | 4685531 | 2164.609 | 1625.766 | 0.439507 |
| 9 | 4679912 | 2163.311 | 1624.811 | 0.440183 |
| 10 | 4681543 | 2163.687 | 1625.365 | 0.439969 |
| 11 | 4676046 | 2162.417 | 1623.825 | 0.440633 |
| 12 | 4671917 | 2161.462 | 1622.463 | 0.441123 |
| 13 | 4663648 | 2159.548 | 1620.973 | 0.442115 |
| 14 | 4661294 | 2159.003 | 1621.043 | 0.442395 |
| 15 | 4643063 | 2154.777 | 1622.585 | 0.444758 |
| 16 | 4639505 | 2153.951 | 1621.533 | 0.445182 |
| 17 | 4638211 | 2153.651 | 1622.331 | 0.445373 |
| 18 | 4648235 | 2155.976 | 1618.812 | 0.443981 |
| 19 | 4646286 | 2155.524 | 1618.784 | 0.444212 |
| 20 | 4650611 | 2156.528 | 1617.275 | 0.443693 |
| 21 | 4654001 | 2157.313 | 1623.167 | 0.443283 |
| 22 | 4656675 | 2157.933 | 1620.82 | 0.442949 |
| 23 | 4651028 | 2156.624 | 1620.096 | 0.443641 |
| 24 | 4658572 | 2158.372 | 1620.891 | 0.442723 |
| 25 | 4660592 | 2158.84 | 1622.141 | 0.442491 |
| 26 | 4670758 | 2161.194 | 1621.974 | 0.441388 |
| 27 | 4662009 | 2159.169 | 1622.854 | 0.442308 |
| 28 | 4662111 | 2159.192 | 1620.957 | 0.442302 |
| 29 | 4657345 | 2158.088 | 1619.483 | 0.44291 |
| 30 | 4663181 | 2159.44 | 1619.394 | 0.442219 |
MSEのプロット
MSEとMSEの95%信頼区間のプロットをみてみよう。
library(ggplot2)
ggplot(metrics, aes(x = repeats, y = MSE)) +
geom_line(color = "blue") +
geom_point(color = "blue") +
geom_ribbon(aes(ymin = MSE_CI_Lower, ymax = MSE_CI_Upper), alpha = 0.2, fill = "gray") +
ylab("MSE") +
xlab("Number of Repeats") +
ggtitle("MSE over Repeats with 95% CI") +
theme_minimal()

RMSEのプロット
次にRMSEとRMSEの95%信頼区間のプロットをみてみよう。
ggplot(metrics, aes(x = repeats, y = RMSE)) +
geom_line(color = "red") +
geom_point(color = "red") +
geom_ribbon(aes(ymin = RMSE_CI_Lower, ymax = RMSE_CI_Upper), alpha = 0.2, fill = "gray") +
ylab("RMSE") +
xlab("Number of Repeats") +
ggtitle("RMSE over Repeats with 95% CI") +
theme_minimal()

MAEのプロット
MAEとMAEの95%信頼区間のプロットをみてみよう。
ggplot(metrics, aes(x = repeats, y = MAE)) +
geom_line(color = "green") +
geom_point(color = "green") +
geom_ribbon(aes(ymin = MAE_CI_Lower, ymax = MAE_CI_Upper), alpha = 0.2, fill = "gray") +
ylab("MAE") +
xlab("Number of Repeats") +
ggtitle("MAE over Repeats with 95% CI") +
theme_minimal()

R-squaredの結果をまとめたグラフ
一応、R二乗値もみよう。
# R-squaredのプロット
ggplot(metrics, aes(x = repeats, y = R2)) +
geom_line(color = "purple") +
geom_point(color = "purple") +
ylab("R-squared") +
xlab("Number of Repeats") +
ggtitle("R-squared over Repeats") +
theme_minimal()

R二乗値はあまりどうでもいいので無視をするべきだろう。
小括
MSE、RMSE、MAEは反復回数が多くなるにしたがって95%区間が狭まるため、やはり反復回数を多くすれば安定するというのは正しいようだ。
各統計量をみると、MSE、RMSEが最も少なかったのは17回目で、それぞれ4638211.3、2153.65であった。17回の次は、16回、1回、15回、19回、18回となった。されほど大差はないようにも見えるが、反復回数は15-19のあいだに集中している。これはこのデータでに関しては、ということではあるが。
MAEは一桁を無視すると、20、19、18、30、29、23回という順番でMSE、RMSEと同じ傾向とはいかないようだ。
統計学的な根拠をもって、回数を決めるというのはできないようだ。
安定というのは、統計量の95%信頼区間が狭いことでもあるので、95%信頼区間から変動点の検出を行ってみる。
RMSEの変動点の検出
# 分析用のデータを作成 RMSE.dat <- subset(metrics, select = c(repeats, RMSE, RMSE_CI_Lower, RMSE_CI_Upper)) # 必要なライブラリをロード library(tidyverse) library(changepoint) # 95%信頼区間の幅を計算 RMSE.dat$CI_width <- RMSE.dat$RMSE_CI_Upper - RMSE.dat$RMSE_CI_Lower # 変動点検出 cpt <- cpt.mean(RMSE.dat$CI_width, method = "PELT") # 結果のプロット ggplot(RMSE.dat, aes(x = repeats, y = CI_width)) + geom_line() + geom_point() + geom_vline(xintercept = cpts(cpt), linetype = "dashed", color = "red") + labs(title = "95%信頼区間幅の変動点検出", x = "Repeats", y = "CI Width") + theme_minimal() # 検出された変動点の表示 cpts(cpt)
結果
1 2 3 4 5 6 7 8 9 10 12 14 16 19 22 26

RMSEのそのまま値だと変動点を見出すのは困難なようだ。
RMSEの平均と分散の変動点を検出
cpt_meanvar <- cpt.meanvar(RMSE.dat$CI_width, method = "PELT") # 結果のプロット ggplot(RMSE.dat, aes(x = repeats, y = CI_width)) + geom_line() + geom_point() + geom_vline(xintercept = cpts(cpt_meanvar), linetype = "dashed", color = "red") + labs(title = "95%信頼区間幅の変動点検出 (平均と分散)", x = "Repeats", y = "CI Width") + theme_minimal() # 検出された変動点の表示 cpts(cpt_meanvar)
結果。
6 14 22

MSEの平均と分散の変動点を検出
# 分析用のデータを作成 MSE.dat <- subset(metrics, select = c(repeats, MSE, MSE_CI_Lower, MSE_CI_Upper)) # 95%信頼区間の幅を計算 MSE.dat$CI_width <- MSE.dat$MSE_CI_Upper - MSE.dat$MSE_CI_Lower # 平均と分散の変動点を検出 cpt_meanvar.MSE <- cpt.meanvar(MSE.dat$CI_width, method = "PELT") # 結果のプロット ggplot(MSE.dat, aes(x = repeats, y = CI_width)) + geom_line() + geom_point() + geom_vline(xintercept = cpts(cpt_meanvar.MSE), linetype = "dashed", color = "red") + labs(title = "95%信頼区間幅の変動点検出 (平均と分散)", x = "Repeats", y = "CI Width") + theme_minimal() # 検出された変動点の表示 cpts(cpt_meanvar.MSE)
結果。
6 14 19 24 26

MAEの平均と分散の変動点を検出
# 分析用のデータを作成 MAE.dat <- subset(metrics, select = c(repeats, MAE, MAE_CI_Lower, MAE_CI_Upper)) # 95%信頼区間の幅を計算 MAE.dat$CI_width <- MAE.dat$MAE_CI_Upper - MAE.dat$MAE_CI_Lower # 平均と分散の変動点を検出 cpt_meanvar.MAE <- cpt.meanvar(MAE.dat$CI_width, method = "PELT") # 結果のプロット ggplot(MAE.dat, aes(x = repeats, y = CI_width)) + geom_line() + geom_point() + geom_vline(xintercept = cpts(cpt_meanvar.MAE), linetype = "dashed", color = "red") + labs(title = "95%信頼区間幅の変動点検出 (平均と分散)", x = "Repeats", y = "CI Width") + theme_minimal() # 検出された変動点の表示 cpts(cpt_meanvar.MAE)
結果。
6 14 22

結論
計算して気づいたがRMSEとMAEの95%信頼区間幅同じになるようだ。そのことを考慮に入れて、RMSEとMSEとMAEを見比べると6回と14回がいずれも変動点と検出されているので、おそらく6回か14回というのがこの分析が進める回数といえるだろう。
6回というのは5から10回という経験則を示したものであり、データによって違うことを考慮すれば、10回に近い方が安全であろう。
今回のようなサンプルサイズが小さなデータであれば、14回が適切な回数となるだろう。
この回数はこのデータで計算した時の数字である。どのデータにも当てはまるものではない。
とはいえ、データごとに変動点は算出できるので、データごとに判断することは可能である。