NumPyroでMCMC - ロジスティックモデルを使った個体数増加のモデリング
概要
前から気になっていたNumpyroを試してみたので覚書。コードは以下のrepositoryにあります。
github.com
NumPyro↓ これを使うと手軽に確率モデリングのコードが書けるらしいです。
num.pyro.ai
お題
ロジスティックモデルにしたがう個体数増加のモデリングをやってみました。
ロジスティックモデルは、を個体数、を時間として以下の微分方程式で表されるモデルです。
ロジスティック方程式 - Wikipedia
が個体数増加速度、が環境収容力です。個体数は指数関数的に増えますが、環境収容力に近づくにつれ速度が抑えられます。菌や細胞を培養したときの挙動に近いです。
上の微分方程式を解くと以下の式になります。は個体数の初期値です。
今回のお題用に、平均が, , のロジスティックモデルにしたがうデータを用意しました。
の区間で、のガンマ分布*1にしたがう100個のデータを生成しました。
NumPyroを使ったマルコフ連鎖モンテカルロ(MCMC)法により、データから、元のパラメータの復元を試みます。
生成されたデータと、ロジスティックモデルにしたがう期待値 ± ガンマ分布のSDのplot。
コードの説明
ライブラリの読み込み
import numpy as np import pandas as pd import seaborn as sns import matplotlib.pyplot as plt %matplotlib inline import jax import numpyro import numpyro.distributions as dist import arviz as az
Toy dataの生成
上述の通りtoy dataを生成します。
def logistic(t, k, r, n0): n = k / (1. + (k / n0 - 1.) * np.exp(-r * t)) return n k = 10. r = 1. n0 = 0.1 scale = 0.1 t_max = 10. n_data = 100 df_data = pd.DataFrame({ 't': np.linspace(0, t_max , n_data) }) np.random.seed(seed = 1) df_data = df_data.assign( n = np.random.gamma(logistic(df_data['t'], k, r, n0) / scale, scale, n_data) )
モデルの記述
上述のデータの生成過程を記述します。事前分布はすべてコーシー or 半コーシー分布とします(無情報のつもり)。
データはnumpy arrayではなくjax numpy arrayとして扱う必要があります。ガンマ分布のパラメータの取り方がnumpyとnumpyroで違うので要注意です。
def logistic_jax(t, k, r, n0): n = k / (1. + (k / n0 - 1.) * jax.numpy.exp(-r * t)) return n def model(t, n, N): k = numpyro.sample('k', dist.HalfCauchy(10.)) r = numpyro.sample('r', dist.Cauchy(10.)) n0 = numpyro.sample('n0', dist.HalfCauchy(10.)) scale = numpyro.sample('scale', dist.HalfCauchy(10.)) n_mean = logistic_jax(t, k, r, n0) with numpyro.plate('data', N): numpyro.sample('n', dist.Gamma(n_mean / scale, 1. / scale), obs = n)
MCMCでサンプリング
MCMCによるサンプリングで事後分布を推定します。今回はNo-U-turn sampler (NUTS) を使います。Chainを4つ用意して並列化します。
iter_warmup = 200 iter_sample = 1000 n_chain = 4 numpyro.set_host_device_count(n_chain) nuts_kernel = numpyro.infer.NUTS(model, adapt_step_size = True) mcmc = numpyro.infer.MCMC(nuts_kernel, num_warmup = iter_warmup, num_samples = iter_sample, num_chains = n_chain, chain_method = 'parallel') mcmc.run(jax.random.PRNGKey(0), t = df_data['t'].values, n = df_data['n'].values, N = df_data.shape[0])
結果の確認
サンプリングされたパラメータ(事後分布)を確認します。
az.plot_trace(mcmc)
4つのchainが正解の値に収束しました。
サンプリングされたパラメータの統計量 (平均, SD等) と、MCMCの収束判定の指標であるを確認します。
dict_diagnostics = numpyro.diagnostics.summary(mcmc.get_samples(group_by_chain = True), prob = 0.95, group_by_chain = True) df_diagnostics = pd.DataFrame.from_dict(dict_diagnostics) df_diagnostics
k | n0 | r | scale | |
---|---|---|---|---|
mean | 10.202814 | 0.096435 | 1.015158 | 0.082979 |
std | 0.183319 | 0.011746 | 0.033786 | 0.012336 |
median | 10.197577 | 0.095709 | 1.015514 | 0.081736 |
2.5% | 9.854436 | 0.073230 | 0.949016 | 0.060573 |
97.5% | 10.567705 | 0.118404 | 1.080463 | 0.108148 |
n_eff | 2063.584710 | 1622.019435 | 1583.550530 | 1993.970682 |
r_hat | 1.002008 | 1.001410 | 1.001241 | 1.000152 |
サンプリングされたパラメータの値は正解の値にほぼ等しく、また、がほぼ1であることからMCMCが収束していることが確認できました。
予測分布の生成(おまけ)
最後に、サンプリングされたパラメータの事後分布を使って、時間に対する個体数の予測分布を生成します。
解析的には求まらないので、サンプリングされたパラメータ (1,000個×4 chains分) からの分布を作ってサンプリングします。
pred = numpyro.infer.Predictive(model, mcmc.get_samples(group_by_chain = False)) n_bin = 100 t_pred = np.linspace(0, t_max, n_bin) n_pred = pred(jax.random.PRNGKey(0), t = t_pred, n = None, N = t_pred.shape[0])['n']
各の値に対して4,000個のがサンプリングされました。その平均とSDを元のデータに重ねてみます。
n_pred_mean = n_pred.mean(0) n_pred_std = n_pred.std(0) sns.scatterplot(x = 't', y = 'n', data = df_data) sns.lineplot(x = t_lineplot, y = n_pred_mean) plt.fill_between(x = t_lineplot, y1 = n_pred_mean - n_pred_std, y2 = n_pred_mean + n_pred_std, alpha = 0.2)
きれいに重なりました!
おわりに
昔stanでベイジアンモデリングの真似事をやったことがありますが、当時に比べるとかなり楽に書けた印象です。Pythonで手軽に書けるのはありがたいです。
今回はGPUが使えずJAXの恩恵にあずかれませんでした。今度dockerでGPUを使う方法を調べてみようと思います。
関連記事
【2023-02-05追記】続編 migamamo-bio.hatenablog.com
参考にさせていただいた記事
*1:平均が定まっているのでガンマ分布のもう一方のパラメータも定まります