井出草平の研究ノート

入れ子交差検証/Nested Cross-Validation[R]

入れ子交差検証(Nested Cross-Validation)は、モデルの性能評価とλの選定におけるバイアスを排除するための検証手法である。この方法は、外側と内側という比喩を用いて説明される。外側のクロスバリデーションと内側のクロスバリデーションの二重構造から成り立っている。具体的には、データセットを複数のサブセットに分割し、外側の反復でモデルの性能を評価し、内側の反復ループでλの選定を行う。

この手順を踏む理由は、一般的なk-foldクロスバリデーションの問題を克服するためである。k-foldクロスバリデーションでは、データセットをk個のフォールドに分割し、それぞれのフォールドを一度ずつテストデータとして使用し、残りのフォールドを訓練データとして使用する。この方法では、λの選定とモデルの性能評価が同じデータセットで行われるため、バイアスが生じる可能性がある。

λはモデルの性能評価によって選定され、モデルの性能評価にはRMSEやMAEといった統計量が使用される。入れ子交差検証では、この選定と評価が同じデータセットで行われることによるバイアスを防ぐことが目的である。

入れ子交差検証では、各フォールドで得られた性能指標を平均し、異なるλの値に対してこのプロセスを繰り返す。最も良い性能を示したλの値が最適なλとして選ばれる。λを選定した後は、k-foldクロスバリデーションと同じ手順でLassoの推定が行われるため、仮に入れ子交差検証とk-foldクロスバリデーションで選定されたλが同じであれば、Lassoの推定値も同じになる。

要するに、入れ子交差検証はλの選定過程をより厳密にすることで、より良いλの選定を行うための技法であると考えられる。

入れ子交差検証の流れ

外側ループ(反復):

フォールド1 (テストセット), フォールド2-5 (訓練セット) -> 内側ループでハイパーパラメータチューニング -> 最適モデルでテストセット評価
フォールド2 (テストセット), フォールド1,3-5 (訓練セット) -> 内側ループでハイパーパラメータチューニング -> 最適モデルでテストセット評価
...
k. フォールドk (テストセット), フォールド1-(k-1) (訓練セット) -> 内側ループでハイパーパラメータチューニング -> 最適モデルでテストセット評価

外側ループでは、データセットをk個のフォールドに分割し、それぞれのフォールドを一度ずつテストデータとして使用し、残りのフォールドを訓練データとして使用する。例えば、5フォールドのクロスバリデーションを行う場合、以下のようになる。

最初のステップでは、フォールド1がテストセットとなり、フォールド2-5が訓練セットとなる。この訓練セットに対して、内側ループでハイパーパラメータのチューニングを行い、最適なモデルを見つける。その後、この最適モデルを用いてフォールド1のテストセットを評価する。

次に、フォールド2がテストセットとなり、フォールド1と3-5が訓練セットとなる。同様に、内側ループでハイパーパラメータのチューニングを行い、最適なモデルを見つけた後に、フォールド2のテストセットを評価する。

これをk回繰り返し、各フォールドをテストセットとして評価することで、全体のモデル評価を行う。

内側ループ(反復):

内側ループでは、外側ループで選ばれた訓練セットをさらにm個のフォールドに分割し、ハイパーパラメータのチューニングを行う。このプロセスは以下の通りである。

訓練セットをm個のフォールドに分割し、各フォールドを一度ずつバリデーションデータとして使用し、残りのフォールドを訓練データとして使用する。

各ハイパーパラメータ(λ)の組み合わせについて、m回のクロスバリデーションを行い、それぞれのバリデーションセットでのモデルの性能を評価する。

すべてのハイパーパラメータ(λ)の組み合わせの中から、バリデーションセットで最も良い性能を示したハイパーパラメータ(λ)の組み合わせを選ぶ。

この内側ループで選ばれた最適なハイパーパラメータ(λ)を用いて、外側ループの訓練セット全体で最終モデルを訓練し、そのモデルを外側ループのテストセットで評価する。

Rでの実装

rdrr.io

# データの読み込みと前処理

library(haven)
library(tidyr)
library(dplyr)

auto <- haven::read_dta("http://www.stata-press.com/data/r9/auto.dta")
auto <- auto %>% drop_na()  # 全カラムに対してNAがない行を抽出
auto$foreign <- as.integer(auto$foreign)

n <- nrow(auto)
price <- auto$price
X <- auto[, c("mpg", "rep78", "headroom", "trunk", "weight", "length",
              "turn", "displacement", "gear_ratio", "foreign")]
X <- as.matrix(X)

並列処理の設定

library(doParallel)
num_cores <- detectCores() - 2  # システム全体のコア数から2つを引いた数を使用
cl <- makeCluster(num_cores)
registerDoParallel(cl)

入れ子交差検証の実行

library(glmnet)
library(nestedcv)

set.seed(123)  # 再現性のためのシード設定

result <- nestcv.glmnet(
  y = price,
  x = X,
  family = "gaussian",
  alphaSet = 1,
  n_outer_folds = 5,  # 外側ループのフォールド数
  n_inner_folds = 5,  # 内側ループのフォールド数
  parallel = TRUE  # 並列処理を有効にする
)

# 並列処理の停止
stopCluster(cl)

結果の表示

print(result)
print(result[["final_coef"]])
Nested cross-validation with glmnet
No filter

Final parameters:
lambda   alpha  
 128.5     1.0  

Final coefficients:
 (Intercept)      foreign     headroom        rep78 displacement       weight 
   -1831.698     2825.830     -394.122      123.420       13.211        1.736 

Result:
     RMSE    Rsquared         MAE   
2149.8921      0.4552   1666.7084 

安定性のシミュレーション

# 必要なライブラリの読み込み
library(haven)
library(tidyr)
library(glmnet)
library(nestedcv)
library(dplyr)
library(doParallel)

# データの読み込みと前処理
auto <- haven::read_dta("http://www.stata-press.com/data/r9/auto.dta")
auto <- auto %>% drop_na()  # 全カラムに対してNAがない行を抽出
auto$foreign <- as.integer(auto$foreign)

price <- auto$price
X <- auto[, c("mpg", "rep78", "headroom", "trunk", "weight", "length",
              "turn", "displacement", "gear_ratio", "foreign")]
X <- as.matrix(X)

# 並列処理の設定
num_cores <- detectCores() - 2  # システム全体のコア数から2つを引いた数を使用
cl <- makeCluster(num_cores)
registerDoParallel(cl)

# 結果を格納するリストを初期化
selected_variables <- list()
coefficients <- list()
jaccard_indices <- numeric()

# 50回のシミュレーション
for (i in 1:50) {
  set.seed(i)  # 異なる乱数シードを設定
  
  # ネステッドクロスバリデーションの実行
  result <- nestcv.glmnet(
    y = price,
    x = X,
    family = "gaussian",
    alphaSet = 1,
    n_outer_folds = 5,  # 外側ループのフォールド数
    n_inner_folds = 5,  # 内側ループのフォールド数
    parallel = TRUE  # 並列処理を有効にする
  )
  
  # 最終モデルの係数を取得
  final_coef <- result$final_fit$glmnet.fit$beta[, result$final_fit$glmnet.fit$lambda == result$final_fit$lambda.min]
  non_zero_coefs <- final_coef[final_coef != 0]
  
  # 選ばれた変数とその係数を保存
  selected_variables[[i]] <- names(non_zero_coefs)
  coefficients[[i]] <- non_zero_coefs
}

# 並行処理の停止
stopCluster(cl)

# 選ばれた変数の回数をカウント
variable_counts <- table(unlist(selected_variables))

# 選ばれた変数の推定値の平均と標準偏差を計算
variable_stats <- lapply(names(variable_counts), function(var) {
  coefs <- unlist(lapply(coefficients, function(coef) if(var %in% names(coef)) coef[var] else NA))
  coefs <- coefs[!is.na(coefs)]
  mean_coef <- mean(coefs, na.rm = TRUE)
  sd_coef <- sd(coefs, na.rm = TRUE)
  data.frame(variable = var, mean = mean_coef, sd = sd_coef)
})

variable_stats <- do.call(rbind, variable_stats)

# Jaccard係数の計算
for (i in 1:(length(selected_variables) - 1)) {
  for (j in (i + 1):length(selected_variables)) {
    set_i <- selected_variables[[i]]
    set_j <- selected_variables[[j]]
    if (length(set_i) > 0 && length(set_j) > 0) {
      jaccard_index <- length(intersect(set_i, set_j)) / length(union(set_i, set_j))
      jaccard_indices <- c(jaccard_indices, jaccard_index)
    }
  }
}

average_jaccard <- mean(jaccard_indices, na.rm = TRUE)

変数が選ばれた回数

variable_counts_df <- as.data.frame(variable_counts)
colnames(variable_counts_df) <- c("Variable", "Count")
print(variable_counts_df)
       Variable Count
1  displacement    50
2       foreign    50
3    gear_ratio    28
4      headroom    50
5        length    17
6           mpg    12
7         rep78    50
8         trunk    12
9          turn    40
10       weight    50

各変数の推定値の平均と標準偏

variable_stats_df <- as.data.frame(variable_stats)
print(variable_stats_df)
       variable        mean          sd
1  displacement   13.167347   0.6169471
2       foreign 2983.161668 193.8486265
3    gear_ratio -164.675853 101.6676337
4      headroom -462.243290  78.4175351
5        length  -38.720636  18.9774730
6           mpg   -9.721489   5.2209098
7         rep78  132.597007  22.3809795
8         trunk   35.476713  15.3503600
9          turn  -61.485420  31.5922149
10       weight    2.397636   0.8185932

平均Jaccard係数

cat("Average Jaccard Index: ", average_jaccard)

Average Jaccard Index: 0.7713382

変数選択のパターンは、選ばれている変数はそのまま、選ばれにくかった変数も選ばれる回数が増加している。 そのため、Jaccard係数も上がっているのだろう。