井出草平の研究ノート

WindowsでPythonを用いた統計解析を行う際に計算にGPUが必要な場合の環境構築

WindowsマシンでPythonでの統計解析、特にベイズ推定を行うときにGPUに計算を投げる場合にはWSL2上のLinuxにAnaconda環境を構築する必要である。理由はいくつかある。

  • JAX/NumPyroのGPU対応はLinux前提で成熟しており再現性が高いため
  • CUDA・cuDNN・NCCLの整合がLinuxのconda-forgeで一貫管理できるため
  • PyMCのJAX経路がLinuxで最も素直に動作し情報資産も豊富なため
  • WSL2はWindowsドライバを用いながらLinux用CUDAを直接使えるため
  • 数値計算系のビルド・ツールチェーンがLinux最適化でトラブルが少ないため
  • アップデートと移植の速度・安定性がLinuxのほうが高く運用コストが低いため
  • WSL内にAnacondaを置けばWindowsPythonとの混線を避け環境破綻を防げるため

多くの頻度論の計算であれば、Windows上のAnaconda環境で十分だが、GPUに計算を投げる場合にはそうはいかない。macOSであればいいのかというと、そもそもmacOSはCUDAをサポートしない=NVIDIAのボードを使えないので論外である。最適なのは、ネイティブのLinuxNVIDIA環境である。

1. WSL2とGPU準備

NVIDIA Windowsドライバを最新版へ更新(「CUDA対応・DCH」推奨)。

管理者PowerShellでWSLを導入・更新:

wsl --install -d Ubuntu
wsl --update
wsl --set-default-version 2

Ubuntu初回起動後に確認:

nvidia-smi

GPUが見えればOK(WSL内で実行)

2. Conda(Linux側)をホームにインストール

Anaconda(またはMiniconda)Linux版を/home配下に入れる(例:/home//anaconda3)。

conda初期化:

~/anaconda3/bin/conda init bash
exec $SHELL

.condarcをLinux側に設定(最重要):

nano ~/.condarc
envs_dirs:
  - /home/<user>/anaconda3/envs
pkgs_dirs:
  - /home/<user>/anaconda3/pkgs
channels:
  - conda-forge
  - defaults
channel_priority: strict

これで新規環境が/mnt/c(Windows側)ではなくLinuxの/home側に作られる。

3. GPU用ベース環境の作成

conda create -n bayes-gpu-wsl python=3.11 -y
conda activate bayes-gpu-wsl
# 主要パッケージ(jax/jaxlibはCUDA版をconda-forgeで取得)
conda install -c conda-forge pymc arviz numpyro blackjax jupyter jax jaxlib -y

自動でCUDA 12.9系, cuDNN, NCCLなど必要ランタイムが入る

4. 動作検証(JAXがGPUを見るか)

python -c "import jax; print('devices:', jax.devices())"

=> [CudaDevice(id=0)] が出ればOK

5. JupyterLab起動(ブラウザから使用)

mkdir -p ~/bayes-proj && cd ~/bayes-proj
jupyter lab --no-browser --ip=127.0.0.1 --port=8888

表示された http://127.0.0.1:8888/lab?token=... を開く

ブラウザでJupterLabを開くことができる

カーネルはbayes-gpu-wsl環境のPython(「Python 3 (ipykernel)」で実体が/home/.../envs/bayes-gpu-wsl/bin/python)を選ぶ。

6. PyMC+NumPyroでのGPUサンプリング最小例

JupyterLab内で下記のコードを走らせて、テスト

import pymc as pm, numpy as np, arviz as az
from pymc.sampling.jax import sample_numpyro_nuts

rng = np.random.default_rng(123)
x = rng.normal(size=200)
y = 1.5 + 2.0*x + rng.normal(scale=0.5, size=200)

with pm.Model() as m:
    a = pm.Normal("alpha", 0, 10)
    b = pm.Normal("beta",  0, 10)
    s = pm.HalfNormal("sigma", 1)
    pm.Normal("obs", mu=a + b*x, sigma=s, observed=y)

    idata = sample_numpyro_nuts(
        draws=1000, tune=500, chains=4, target_accept=0.9,
        chain_method="vectorized",                 # 1GPUで複数チェーン
        idata_kwargs={"log_likelihood": True},     # 事後評価用途に便利
    )

print(az.summary(idata, var_names=["alpha","beta","sigma"]))
  1. バックアップと別PC展開

余計な依存を省いた履歴ベースのエクスポート

conda env export --from-history -n bayes-gpu-wsl > bayes-gpu-wsl.yml

別PCでの復元

conda env create -f bayes-gpu-wsl.yml
conda activate bayes-gpu-wsl
python -c "import jax; print(jax.devices())"   # [CudaDevice(id=0)] を確認

チェック用スニペット(情報出力)

6の前くらいにするとよい

import sys, jax, pymc as pm, numpyro, arviz as az
print(sys.executable)
print("PyMC:", pm.__version__)
print("JAX:", jax.__version__)
print("backend:", jax.default_backend())
print("devices:", jax.devices())

正常のものの例 - .../envs/bayes-gpu-wsl/bin/python - backend: gpu - devices: [CudaDevice(id=0)]

注意事項

最重要:/mnt/cを使わない

プロジェクトやConda環境は/home配下。I/Oも速く、PATH混線も防げる。WindowsのCondaを使うとかなり遅くなるので注意

.condarcの優先順位

envs_dirs と pkgs_dirs が /home側を指すこと。ここがズレると/mnt/c/...に作られて壊れやすい。ここではまってしまって困った。

カーネルの取り違い

NotebookのKernel → Change Kernelで、実体パスが/home/のPythonを選ぶ。

JAXの並列警告

1GPUで複数チェーン時は chain_method="vectorized" を付ける(推奨)。家庭用パソコンではGPU1枚体制で計算をするはずなので。

ArviZのlooでlog_likelihoodが無い

idata_kwargs={"log_likelihood": True}sample_numpyro_nutsに渡す。

VS Code Remote WSLが落ちる/重い

一時的にブラウザJupyterLab運用が安定。どうしてもVS Codeなら導入する。

更新

conda activate bayes-gpu-wsl
conda update -n base -c conda-forge conda
conda update -c conda-forge --all

ワンライナー起動(~/.bashrcに追記)

環境を立ち上げるときに

alias jl='conda activate bayes-gpu-wsl >/dev/null 2>&1 && jupyter lab --no-browser --ip=127.0.0.1 --port=8888'