井出草平の研究ノート

R glmnetパッケージとStataのLassoを比較する[R][Stata]

LASOOPACKの作者が簡単に解説していたので、まずはそれを参考にシミュレーションをしてみる。

statalasso.github.io

基本形

Stata

// データの読み込み
webuse auto, clear
drop if rep78==.

// Lasso 10-folds CV
lasso linear price mpg-foreign, grid(1, min(100))
lassoselect lambda = 100
lassocoef, display(coef, penalized)

結果。

------------------------
             |    active
-------------+----------
       rep78 |  126.2449
    headroom |  -453.089
      weight |  1.970002
        turn | -41.05395
displacement |  13.77457
  gear_ratio |  -12.9703
     foreign |  2921.761
       _cons | -841.6243
------------------------

ラムダは100に指定。grid(1, min(100))は1から100までの連続したラムダ値を試行して最適なモデルを見つけるというコマンド。指定するラムダがこの範囲に入っていないとエラーが出る。

R glmnetパッケージ

準備

library(haven)
library(tidyr)
auto <- haven::read_dta("http://www.stata-press.com/data/r9/auto.dta")
auto <- auto %>% drop_na()  # 全カラムに対してNAがない行を抽出

n <- nrow(auto)
price <- auto$price
X <- auto[, c("mpg", "rep78", "headroom", "trunk", "weight", "length",
              "turn", "displacement", "gear_ratio", "foreign")]
X$foreign <- as.integer(X$foreign)
X <- as.matrix(X)

Lasso

r<-glmnet(X,price,alpha=1,lambda=100, thresh=1e-15)
coef(r,s=100)

thresh=1e-15は収束基準の閾値を指定している。デフォルト値は1e-7だがStataと一致させるには、1e-15を指定しないといけないようだ。 s=100はラムダの値を入れる。

                      s1
(Intercept)  -841.625330
mpg             .       
rep78         126.244945
headroom     -453.089041
trunk           .       
weight          1.970003
length          .       
turn          -41.053924
displacement   13.774563
gear_ratio    -12.970463
foreign      2921.761357

乱数

Stataはrseed(123)で、Rはset.seed(123)で乱数シードを固定する。
StataとRの乱数はどちらもメルセンヌ・ツイスターが使われているが、細かい計算手順が違うため、同じシードを入れても同じ乱数は発生できない。
Pythonで乱数を発生させ、StataとRで読み込んで見たが失敗、Stataで乱数を発生させcsvで吐き出し、Rで読み込んでみたが、それも失敗。
良い方法があるのかもしれないが、今回は良い方法を探し当てることはできなかった。

その過程でRに乱数からシードを推定するという機能があるのを知った。

# 読み込んだ乱数をシード値に変換
seed_value <- as.integer(sum(random_values$random_value) %% 1e9)

Stataでベストラムダを算出してR glmnetパッケージで計算する

Stataでまずベストラムダを計算する。

// データの読み込み
webuse auto, clear
drop if rep78==.

. lasso linear price mpg-foreign, selection(cv, folds(10)) nolog rseed(123)

Lasso linear model                          No. of obs        =         69
                                            No. of covariates =         10
Selection: Cross-validation                 No. of CV folds   =         10

--------------------------------------------------------------------------
         |                                No. of      Out-of-      CV mean
         |                               nonzero       sample   prediction
      ID |     Description      lambda     coef.    R-squared        error
---------+----------------------------------------------------------------
       1 |    first lambda    1584.189         0      -0.0126      8465030
      25 |   lambda before    169.8674         5       0.4283      4779268
    * 26 | selected lambda    154.7769         5       0.4289      4773852
      27 |    lambda after    141.0269         5       0.4288      4775047
      30 |     last lambda    106.6817         6       0.4264      4794955
--------------------------------------------------------------------------
* lambda selected by cross-validation.

lamda = 154.7769であった。

結果も計算しておこう。

. lasso linear price mpg-foreign, lambda(`e(lambda)') rseed(123)
. stimates store CV
. lassocoef CV, display(coef, penalized)

------------------------
             |        CV
-------------+----------
       rep78 |  98.46326
    headroom | -332.6659
      weight |  1.661591
displacement |  12.71576
     foreign |  2692.259
       _cons | -1566.399
------------------------

次はR glmnetパッケージ。

r2<-glmnet(X,price,alpha=1, thresh=1e-15)
coef(r2,s=154.7769)
                      s1
(Intercept)  -1566.39783
mpg              .      
rep78           98.46325
headroom      -332.66581
trunk            .      
weight           1.66159
length           .      
turn             .      
displacement    12.71576
gear_ratio       .      
foreign       2692.25840

ラムダが同じであれば、StataとR glmnetパッケージは同じ計算ができることが分かった。双方の結果が異なる場合は、CVでベストラムダを探索する際の乱数の違いによって生じるということである。