井出草平の研究ノート

反復交差検証 / Repeated Cross-Validationを用いたLasso、シミュレーションを行いベストな反復回数を求める[R]

Repeated Cross-Validation(RCV)はCVを反復して安定性を高めようという方法である。

ides.hatenablog.com

ただ、何回繰り返せばよいかわからない。
説明を読んでいると5回から10回くらいであろうと書いてあるが、その根拠は経験則のようだ。 社会調査で使うデータの反復計算は短時間で可能だが、データビッグデータでLassoを行った場合、計算時間と計算の安定性はトレードオフの関係になり、バランスを取ることが求められた結果、5回から10回という設定が落としどころなのだ。

実際にどのくらいの回数が適切なのかを調べるために、シミュレーションをしてみた。

RCVの反復回数は1回から30回までを実験することにした。回数が多くなれば安定性が高まるといわれているが、本当にそうなのかを検証する。
Lassoは乱数の影響を受けやすいため、乱数は5つの異なるシードを用意し、各回数ごとに5回推定を行い、統計量はその平均を取った。

データの読み込みと前処理

library(haven)
library(tidyr)

auto <- haven::read_dta("http://www.stata-press.com/data/r9/auto.dta")
auto <- auto %>% drop_na()  # 全カラムに対してNAがない行を抽出
auto$foreign <- as.integer(auto$foreign)

independent.var <- auto[, c("price", "mpg", "rep78", "headroom", "trunk", "weight", "length", "turn", "displacement", "gear_ratio", "foreign")]

doParallelパッケージを用いて1コアではなく複数コアで計算するようにして、計算時間を短縮させるようにした。

library(doParallel)

# 利用可能なコア数を検出 システム全体のコア数から2つを引く
num_cores <- detectCores() - 2
# クラスタを作成し、コア数を指定
cl <- makeCluster(num_cores)
# 並列バックエンドを登録
registerDoParallel(cl)

1回から30回の反復計算

library(caret)
library(glmnet)

results <- list()

for (repeats in 1:30) {
  predictions <- list()
  for (seed in 1:5) {
    set.seed(seed)
    train_control <- trainControl(method = "repeatedcv", 
                                  number = 10,  # 10-fold CV
                                  repeats = repeats,  # 繰り返しの回数
                                  savePredictions = "final")
    model <- train(price ~ ., data = independent.var, 
                   method = "glmnet",  # Lasso回帰
                   trControl = train_control,
                   tuneLength = 10)  # チューニングするλの数
    
    predictions[[seed]] <- model$pred
  }
  
  # すべての予測値を結合
  all_preds <- do.call(rbind, predictions)
  results[[paste("repeats", repeats, sep = "_")]] <- all_preds
}

統計量の計算

calculate_metrics <- function(predictions) {
  actual <- as.numeric(predictions$obs)
  predicted <- as.numeric(predictions$pred)
  
  mse <- mean((actual - predicted)^2)
  rmse <- sqrt(mse)
  mae <- mean(abs(actual - predicted))
  rsquared <- cor(actual, predicted)^2
  
  # 95%信頼区間の計算
  mse_se <- sd((actual - predicted)^2) / sqrt(length(actual))
  mse_ci_lower <- mse - 1.96 * mse_se
  mse_ci_upper <- mse + 1.96 * mse_se
  
  rmse_se <- sd(sqrt((actual - predicted)^2)) / sqrt(length(actual))
  rmse_ci_lower <- rmse - 1.96 * rmse_se
  rmse_ci_upper <- rmse + 1.96 * rmse_se
  
  mae_se <- sd(abs(actual - predicted)) / sqrt(length(actual))
  mae_ci_lower <- mae - 1.96 * mae_se
  mae_ci_upper <- mae + 1.96 * mae_se
  
  list(MSE = mse, RMSE = rmse, MAE = mae, R2 = rsquared, 
       MSE_CI_Lower = mse_ci_lower, MSE_CI_Upper = mse_ci_upper,
       RMSE_CI_Lower = rmse_ci_lower, RMSE_CI_Upper = rmse_ci_upper,
       MAE_CI_Lower = mae_ci_lower, MAE_CI_Upper = mae_ci_upper)
}

metrics <- data.frame(repeats = integer(), 
                      MSE = numeric(), RMSE = numeric(), MAE = numeric(), R2 = numeric(), 
                      MSE_CI_Lower = numeric(), MSE_CI_Upper = numeric(),
                      RMSE_CI_Lower = numeric(), RMSE_CI_Upper = numeric(),
                      MAE_CI_Lower = numeric(), MAE_CI_Upper = numeric())

for (repeats in 1:30) {
  preds <- results[[paste("repeats", repeats, sep = "_")]]
  metrics_res <- calculate_metrics(preds)
  metrics <- rbind(metrics, data.frame(repeats = repeats, 
                                       MSE = metrics_res$MSE, 
                                       RMSE = metrics_res$RMSE, 
                                       MAE = metrics_res$MAE, 
                                       R2 = metrics_res$R2,
                                       MSE_CI_Lower = metrics_res$MSE_CI_Lower,
                                       MSE_CI_Upper = metrics_res$MSE_CI_Upper,
                                       RMSE_CI_Lower = metrics_res$RMSE_CI_Lower,
                                       RMSE_CI_Upper = metrics_res$RMSE_CI_Upper,
                                       MAE_CI_Lower = metrics_res$MAE_CI_Lower,
                                       MAE_CI_Upper = metrics_res$MAE_CI_Upper))
}

# 並行処理の停止
stopCluster(cl)

結果の表示

print(metrics)
write.csv(metrics, "metrics.csv")
repeats MSE RMSE MAE R2
1 4640526 2154.188 1622.182 0.44507
2 4657895 2158.216 1621.973 0.442928
3 4691725 2166.039 1613.545 0.439773
4 4667408 2160.418 1614.21 0.442108
5 4653046 2157.092 1617.974 0.44343
6 4664083 2159.649 1621.68 0.442089
7 4671125 2161.278 1622.283 0.441242
8 4685531 2164.609 1625.766 0.439507
9 4679912 2163.311 1624.811 0.440183
10 4681543 2163.687 1625.365 0.439969
11 4676046 2162.417 1623.825 0.440633
12 4671917 2161.462 1622.463 0.441123
13 4663648 2159.548 1620.973 0.442115
14 4661294 2159.003 1621.043 0.442395
15 4643063 2154.777 1622.585 0.444758
16 4639505 2153.951 1621.533 0.445182
17 4638211 2153.651 1622.331 0.445373
18 4648235 2155.976 1618.812 0.443981
19 4646286 2155.524 1618.784 0.444212
20 4650611 2156.528 1617.275 0.443693
21 4654001 2157.313 1623.167 0.443283
22 4656675 2157.933 1620.82 0.442949
23 4651028 2156.624 1620.096 0.443641
24 4658572 2158.372 1620.891 0.442723
25 4660592 2158.84 1622.141 0.442491
26 4670758 2161.194 1621.974 0.441388
27 4662009 2159.169 1622.854 0.442308
28 4662111 2159.192 1620.957 0.442302
29 4657345 2158.088 1619.483 0.44291
30 4663181 2159.44 1619.394 0.442219

MSEのプロット

MSEとMSEの95%信頼区間のプロットをみてみよう。

library(ggplot2)
ggplot(metrics, aes(x = repeats, y = MSE)) +
  geom_line(color = "blue") +
  geom_point(color = "blue") +
  geom_ribbon(aes(ymin = MSE_CI_Lower, ymax = MSE_CI_Upper), alpha = 0.2, fill = "gray") +
  ylab("MSE") +
  xlab("Number of Repeats") +
  ggtitle("MSE over Repeats with 95% CI") +
  theme_minimal()

RMSEのプロット

次にRMSEとRMSEの95%信頼区間のプロットをみてみよう。

ggplot(metrics, aes(x = repeats, y = RMSE)) +
  geom_line(color = "red") +
  geom_point(color = "red") +
  geom_ribbon(aes(ymin = RMSE_CI_Lower, ymax = RMSE_CI_Upper), alpha = 0.2, fill = "gray") +
  ylab("RMSE") +
  xlab("Number of Repeats") +
  ggtitle("RMSE over Repeats with 95% CI") +
  theme_minimal()

MAEのプロット

MAEとMAEの95%信頼区間のプロットをみてみよう。

ggplot(metrics, aes(x = repeats, y = MAE)) +
  geom_line(color = "green") +
  geom_point(color = "green") +
  geom_ribbon(aes(ymin = MAE_CI_Lower, ymax = MAE_CI_Upper), alpha = 0.2, fill = "gray") +
  ylab("MAE") +
  xlab("Number of Repeats") +
  ggtitle("MAE over Repeats with 95% CI") +
  theme_minimal()

R-squaredの結果をまとめたグラフ

一応、R二乗値もみよう。

# R-squaredのプロット
ggplot(metrics, aes(x = repeats, y = R2)) +
  geom_line(color = "purple") +
  geom_point(color = "purple") +
  ylab("R-squared") +
  xlab("Number of Repeats") +
  ggtitle("R-squared over Repeats") +
  theme_minimal()

R二乗値はあまりどうでもいいので無視をするべきだろう。

小括

MSE、RMSE、MAEは反復回数が多くなるにしたがって95%区間が狭まるため、やはり反復回数を多くすれば安定するというのは正しいようだ。

各統計量をみると、MSE、RMSEが最も少なかったのは17回目で、それぞれ4638211.3、2153.65であった。17回の次は、16回、1回、15回、19回、18回となった。されほど大差はないようにも見えるが、反復回数は15-19のあいだに集中している。これはこのデータでに関しては、ということではあるが。

MAEは一桁を無視すると、20、19、18、30、29、23回という順番でMSE、RMSEと同じ傾向とはいかないようだ。

統計学的な根拠をもって、回数を決めるというのはできないようだ。
安定というのは、統計量の95%信頼区間が狭いことでもあるので、95%信頼区間から変動点の検出を行ってみる。

RMSEの変動点の検出

# 分析用のデータを作成
RMSE.dat <- subset(metrics, select = c(repeats, RMSE, RMSE_CI_Lower, RMSE_CI_Upper))

# 必要なライブラリをロード
library(tidyverse)
library(changepoint)

# 95%信頼区間の幅を計算
RMSE.dat$CI_width <- RMSE.dat$RMSE_CI_Upper - RMSE.dat$RMSE_CI_Lower

# 変動点検出
cpt <- cpt.mean(RMSE.dat$CI_width, method = "PELT")

# 結果のプロット
ggplot(RMSE.dat, aes(x = repeats, y = CI_width)) +
  geom_line() +
  geom_point() +
  geom_vline(xintercept = cpts(cpt), linetype = "dashed", color = "red") +
  labs(title = "95%信頼区間幅の変動点検出", x = "Repeats", y = "CI Width") +
  theme_minimal()

# 検出された変動点の表示
cpts(cpt)

結果

1 2 3 4 5 6 7 8 9 10 12 14 16 19 22 26

RMSEのそのまま値だと変動点を見出すのは困難なようだ。

RMSEの平均と分散の変動点を検出

cpt_meanvar <- cpt.meanvar(RMSE.dat$CI_width, method = "PELT")

# 結果のプロット
ggplot(RMSE.dat, aes(x = repeats, y = CI_width)) +
  geom_line() +
  geom_point() +
  geom_vline(xintercept = cpts(cpt_meanvar), linetype = "dashed", color = "red") +
  labs(title = "95%信頼区間幅の変動点検出 (平均と分散)", x = "Repeats", y = "CI Width") +
  theme_minimal()

# 検出された変動点の表示
cpts(cpt_meanvar)

結果。

6 14 22

MSEの平均と分散の変動点を検出

# 分析用のデータを作成
MSE.dat <- subset(metrics, select = c(repeats, MSE, MSE_CI_Lower, MSE_CI_Upper))

# 95%信頼区間の幅を計算
MSE.dat$CI_width <- MSE.dat$MSE_CI_Upper - MSE.dat$MSE_CI_Lower

# 平均と分散の変動点を検出
cpt_meanvar.MSE <- cpt.meanvar(MSE.dat$CI_width, method = "PELT")

# 結果のプロット
ggplot(MSE.dat, aes(x = repeats, y = CI_width)) +
  geom_line() +
  geom_point() +
  geom_vline(xintercept = cpts(cpt_meanvar.MSE), linetype = "dashed", color = "red") +
  labs(title = "95%信頼区間幅の変動点検出 (平均と分散)", x = "Repeats", y = "CI Width") +
  theme_minimal()

# 検出された変動点の表示
cpts(cpt_meanvar.MSE)

結果。

6 14 19 24 26

MAEの平均と分散の変動点を検出

# 分析用のデータを作成
MAE.dat <- subset(metrics, select = c(repeats, MAE, MAE_CI_Lower, MAE_CI_Upper))

# 95%信頼区間の幅を計算
MAE.dat$CI_width <- MAE.dat$MAE_CI_Upper - MAE.dat$MAE_CI_Lower

# 平均と分散の変動点を検出
cpt_meanvar.MAE <- cpt.meanvar(MAE.dat$CI_width, method = "PELT")

# 結果のプロット
ggplot(MAE.dat, aes(x = repeats, y = CI_width)) +
  geom_line() +
  geom_point() +
  geom_vline(xintercept = cpts(cpt_meanvar.MAE), linetype = "dashed", color = "red") +
  labs(title = "95%信頼区間幅の変動点検出 (平均と分散)", x = "Repeats", y = "CI Width") +
  theme_minimal()

# 検出された変動点の表示
cpts(cpt_meanvar.MAE)

結果。

6 14 22

結論

計算して気づいたがRMSEとMAEの95%信頼区間幅同じになるようだ。そのことを考慮に入れて、RMSEとMSEとMAEを見比べると6回と14回がいずれも変動点と検出されているので、おそらく6回か14回というのがこの分析が進める回数といえるだろう。
6回というのは5から10回という経験則を示したものであり、データによって違うことを考慮すれば、10回に近い方が安全であろう。
今回のようなサンプルサイズが小さなデータであれば、14回が適切な回数となるだろう。

この回数はこのデータで計算した時の数字である。どのデータにも当てはまるものではない。
とはいえ、データごとに変動点は算出できるので、データごとに判断することは可能である。