WindowsマシンでPythonでの統計解析、特にベイズ推定を行うときにGPUに計算を投げる場合にはWSL2上のLinuxにAnaconda環境を構築する必要である。理由はいくつかある。
- JAX/NumPyroのGPU対応はLinux前提で成熟しており再現性が高いため
- CUDA・cuDNN・NCCLの整合がLinuxのconda-forgeで一貫管理できるため
- PyMCのJAX経路がLinuxで最も素直に動作し情報資産も豊富なため
- WSL2はWindowsドライバを用いながらLinux用CUDAを直接使えるため
- 数値計算系のビルド・ツールチェーンがLinux最適化でトラブルが少ないため
- アップデートと移植の速度・安定性がLinuxのほうが高く運用コストが低いため
- WSL内にAnacondaを置けばWindows側Pythonとの混線を避け環境破綻を防げるため
多くの頻度論の計算であれば、Windows上のAnaconda環境で十分だが、GPUに計算を投げる場合にはそうはいかない。macOSであればいいのかというと、そもそもmacOSはCUDAをサポートしない=NVIDIAのボードを使えないので論外である。最適なのは、ネイティブのLinux+NVIDIA環境である。
1. WSL2とGPU準備
NVIDIA Windowsドライバを最新版へ更新(「CUDA対応・DCH」推奨)。
管理者PowerShellでWSLを導入・更新:
wsl --install -d Ubuntu wsl --update wsl --set-default-version 2
Ubuntu初回起動後に確認:
nvidia-smi
GPUが見えればOK(WSL内で実行)
2. Conda(Linux側)をホームにインストール
Anaconda(またはMiniconda)Linux版を/home配下に入れる(例:/home/
conda初期化:
~/anaconda3/bin/conda init bash exec $SHELL
.condarcをLinux側に設定(最重要):
nano ~/.condarc
envs_dirs: - /home/<user>/anaconda3/envs pkgs_dirs: - /home/<user>/anaconda3/pkgs channels: - conda-forge - defaults channel_priority: strict
これで新規環境が/mnt/c(Windows側)ではなくLinuxの/home側に作られる。
3. GPU用ベース環境の作成
conda create -n bayes-gpu-wsl python=3.11 -y conda activate bayes-gpu-wsl # 主要パッケージ(jax/jaxlibはCUDA版をconda-forgeで取得) conda install -c conda-forge pymc arviz numpyro blackjax jupyter jax jaxlib -y
自動でCUDA 12.9系, cuDNN, NCCLなど必要ランタイムが入る
4. 動作検証(JAXがGPUを見るか)
python -c "import jax; print('devices:', jax.devices())"
=> [CudaDevice(id=0)] が出ればOK
5. JupyterLab起動(ブラウザから使用)
mkdir -p ~/bayes-proj && cd ~/bayes-proj jupyter lab --no-browser --ip=127.0.0.1 --port=8888
表示された http://127.0.0.1:8888/lab?token=... を開く
ブラウザでJupterLabを開くことができる
カーネルはbayes-gpu-wsl環境のPython(「Python 3 (ipykernel)」で実体が/home/.../envs/bayes-gpu-wsl/bin/python)を選ぶ。
6. PyMC+NumPyroでのGPUサンプリング最小例
JupyterLab内で下記のコードを走らせて、テスト
import pymc as pm, numpy as np, arviz as az
from pymc.sampling.jax import sample_numpyro_nuts
rng = np.random.default_rng(123)
x = rng.normal(size=200)
y = 1.5 + 2.0*x + rng.normal(scale=0.5, size=200)
with pm.Model() as m:
a = pm.Normal("alpha", 0, 10)
b = pm.Normal("beta", 0, 10)
s = pm.HalfNormal("sigma", 1)
pm.Normal("obs", mu=a + b*x, sigma=s, observed=y)
idata = sample_numpyro_nuts(
draws=1000, tune=500, chains=4, target_accept=0.9,
chain_method="vectorized", # 1GPUで複数チェーン
idata_kwargs={"log_likelihood": True}, # 事後評価用途に便利
)
print(az.summary(idata, var_names=["alpha","beta","sigma"]))
- バックアップと別PC展開
余計な依存を省いた履歴ベースのエクスポート
conda env export --from-history -n bayes-gpu-wsl > bayes-gpu-wsl.yml
別PCでの復元
conda env create -f bayes-gpu-wsl.yml conda activate bayes-gpu-wsl python -c "import jax; print(jax.devices())" # [CudaDevice(id=0)] を確認
チェック用スニペット(情報出力)
6の前くらいにするとよい
import sys, jax, pymc as pm, numpyro, arviz as az
print(sys.executable)
print("PyMC:", pm.__version__)
print("JAX:", jax.__version__)
print("backend:", jax.default_backend())
print("devices:", jax.devices())
正常のものの例 - .../envs/bayes-gpu-wsl/bin/python - backend: gpu - devices: [CudaDevice(id=0)]
注意事項
最重要:/mnt/cを使わない
プロジェクトやConda環境は/home配下。I/Oも速く、PATH混線も防げる。WindowsのCondaを使うとかなり遅くなるので注意
.condarcの優先順位
envs_dirs と pkgs_dirs が /home側を指すこと。ここがズレると/mnt/c/...に作られて壊れやすい。ここではまってしまって困った。
カーネルの取り違い
NotebookのKernel → Change Kernelで、実体パスが/home/のPythonを選ぶ。
JAXの並列警告
1GPUで複数チェーン時は chain_method="vectorized" を付ける(推奨)。家庭用パソコンではGPU1枚体制で計算をするはずなので。
ArviZのlooでlog_likelihoodが無い
idata_kwargs={"log_likelihood": True}をsample_numpyro_nutsに渡す。
VS Code Remote WSLが落ちる/重い
一時的にブラウザJupyterLab運用が安定。どうしてもVS Codeなら導入する。
更新
conda activate bayes-gpu-wsl conda update -n base -c conda-forge conda conda update -c conda-forge --all
ワンライナー起動(~/.bashrcに追記)
環境を立ち上げるときに
alias jl='conda activate bayes-gpu-wsl >/dev/null 2>&1 && jupyter lab --no-browser --ip=127.0.0.1 --port=8888'