Numpyroを使った常微分方程式のパラメータ推定 - Lotka-Volterra方程式
概要
Numpyroで微分方程式のパラメータ推定をやってみました。コードは前回のNumPuroデビュー記事のコードと同じrepositoryにあります。 github.com
前回のNumPyroデビュー記事はこちら。 migamamo-bio.hatenablog.com
調べてみたところ、公式のNumpyro応用例の中で微分方程式を扱っている例を見つけました。これを参考にコードを書いていきます。
num.pyro.ai
お題
今回のお題は、数理生物学界隈で有名なLotka-Volterra equation ("predator-prey model") です。
Lotka-Volterra方程式は、捕食者 (predator) と被食者 (prey) の個体数を表現する以下の2つの常微分方程式です。
ロトカ・ヴォルテラの方程式 - Wikipedia
が時間、がpreyの個体数、がpredatorの個体数、が正の実数値をとるパラメータです。Preyの個体数は指数関数的に増えますが、predatorとの遭遇確率に応じて減少します。Predatorの個体数は指数関数的に減りますが、preyとの遭遇確率に応じて増加します。Lotka-Volterra方程式は解析的に解けないので、数値的に解きながらパラメータ推定を行います。
今回は、Numpyroに含まれるテスト用データLYNXHARE
を使います。中身は1845~1935年のCanada lynx*1 (predator) とsnowshoe hare*2 (prey) の個体数 (in thousands) とされているデータ*3です。
コードの説明
ライブラリの読み込み
ライブラリを読み込みます。ScipyではなくJAXからodeintを読み込みます。
import numpy as np import pandas as pd import seaborn as sns import matplotlib.pyplot as plt %matplotlib inline import jax import numpyro import numpyro.distributions as dist import arviz as az from numpyro.examples.datasets import LYNXHARE, load_dataset from jax.experimental.ode import odeint
データの読み込み
_, fetch = load_dataset(LYNXHARE, shuffle=False)
t, n = fetch()
モデルの記述
モデルを記述します。
def dn_dt(n, t, n_eq, theta): u = n[0] v = n[1] v_eq = n_eq[0] # v_eq = alpha / beta u_eq = n_eq[1] # u_eq = gamma / delta beta = theta[0] delta = theta[1] du_dt = (v_eq - v) * beta * u dv_dt = (u - u_eq) * delta * v return jax.numpy.stack([du_dt, dv_dt]) def model(t, n, N): n_eq = numpyro.sample("n_eq", dist.LogNormal(jax.numpy.log(30.), 0.5).expand([2])) theta = numpyro.sample('theta', dist.HalfNormal(0.1).expand([2])) n0 = numpyro.sample("n0", dist.LogNormal(jax.numpy.log(30.), 0.5).expand([2])) n_mean = odeint(dn_dt, n0, t, n_eq, theta, rtol = 1e-6, atol = 1e-5, mxstep = 1000) scale = numpyro.sample("scale", dist.HalfCauchy(10.).expand([2])) numpyro.sample("n", dist.Gamma(n_mean / scale, 1. / scale), obs = n)
JAXのodeint
(Scipyのodeint
ではない) を使ってdn_dt
を数値的に解きながらパラメータをサンプリングします。
公式のNumpyro応用例に載っているコードとパラメータの取り方を変えました。n_eq
はLotka-Volterra方程式の平衡点(predatorとpreyの個体数が均衡して変化しない点)です。平衡点は、, です。
個体数関連のパラメータ (数値積分の初期値n0
とn_eq
= ) は10~100くらいの範囲のはずなので、事前分布を対数正規分布 (log30, 0.5) とします。微分方程式のパラメータ () の事前分布は半コーシー分布 (0, 0.1) とします。元データの雰囲気からprey (predator) の増加 (減少) 速度は1年あたり1桁倍くらいのように見えるので、, の値は1前後くらいのはず、したがっては0.01 ~ 0.1くらいではないかという目算です。また、データはガンマ分布にしたがって生成されることとします。
MCMC
今回はNo-U-turn sampler (NUTS) を使います。Chainを4つ用意して並列化します。
iter_warmup = 1000 iter_sample = 1000 n_chain = 4 t_scaled = t.astype(float) - min(t.astype(float)) numpyro.set_host_device_count(n_chain) nuts_kernel = numpyro.infer.NUTS(model, adapt_step_size = True, init_strategy = numpyro.infer.init_to_sample()) mcmc = numpyro.infer.MCMC(nuts_kernel, num_warmup = iter_warmup, num_samples = iter_sample, num_chains = n_chain, chain_method = 'parallel')
今回は初期値を事前分布からサンプリングします (init_strategy = numpyro.infer.init_to_sample()
の部分)。デフォルトの初期値の選び方だとMCMCが収束しませんでした。マニュアルを読むと、デフォルトでは事前分布ではなく一様分布*4から初期値がサンプリングされるそうです。
MCMC warmup
Warmupのみ実行して各パラメータの動きを確認します。
今回は複数のパラメータをarrayとしてまとめて定義しました。各要素は別の色で表示されます。青が[0]の位置の要素、オレンジが[1]の位置の要素です。
backend_kwargs = {'constrained_layout': True}
としないと前回記事のように図の見た目が崩れます。
mcmc.warmup(jax.random.PRNGKey(0), t = t_scaled, n = n, N = n.shape[0], collect_warmup = True) az.plot_trace(mcmc, backend_kwargs = {'constrained_layout': True})
若干苦しそうですが1,000 iterationsで収束してくれました。
MCMC sampling
事後分布をサンプリングしていきます。
mcmc.run(jax.random.PRNGKey(0), t = t_scaled, n = n, N = n.shape[0]) az.plot_trace(mcmc, backend_kwargs = {'constrained_layout': True})
サンプリングされたパラメータの統計量 (平均, SD等) と、MCMCの収束判定の指標であるを確認します。
dict_diagnostics = numpyro.diagnostics.summary(mcmc.get_samples(group_by_chain = True), prob = 0.95, group_by_chain = True) df_diagnostics = pd.DataFrame.from_dict(dict_diagnostics) df_diagnostics
n0 | n_eq | scale | theta | |
---|---|---|---|---|
mean | [65.87, 36.49] | [27.74, 43.62] | [28.37, 5.25] | [0.015, 0.026] |
std | [8.02, 4.75] | [1.26, 3.47] | [4.70, 0.84] | [0.0038, 0.0086] |
median | [65.59, 36.32] | [27.67, 43.53] | [27.87, 5.18] | [0.015, 0.024] |
2.5% | [49.67, 27.80] | [25.36, 36.75] | [19.90, 3.77] | [0.0078, 0.015] |
97.5% | [80.84, 46.07] | [30.26, 50.26] | [37.86, 6.97] | [0.023, 0.043] |
n_eff | [1438.68, 2273.09] | [2315.89, 2066.26] | [2259.36, 2527.52] | [1426.17, 904.76] |
r_hat | [1.0016, 1.0019] | [1.0024, 1.0002] | [1.0007, 1.0013] | [1.0019, 1.0034] |
がほぼ1であることからMCMCが収束したことが確認できました。
予測分布の生成
サンプリングされたパラメータの事後分布を使って、時間に対するpreyの個体数とpredatorの個体数の予測分布を生成します。
解析的には求まらないので、サンプリングされたパラメータ (1,000個×4 chains分) からの分布を作ってサンプリングします。
pred = numpyro.infer.Predictive(model, mcmc.get_samples(group_by_chain = False)) n_bin = 1000 t_pred_scaled = np.linspace(0, max(t_scaled), n_bin) t_pred = t_pred_scaled + min(t) n_pred = pred(jax.random.PRNGKey(0), t = t_pred_scaled, n = None, N = t_pred_scaled.shape[0])['n']
各の値に対して4,000個のがサンプリングされました。その平均とSDを元のデータに重ねてみます。
n_pred_mean = n_pred.mean(0) n_pred_std = n_pred.std(0) sns.scatterplot(x = t, y = n[:, 0], label = 'prey') sns.lineplot(x = t, y = n[:, 0], linestyle = 'dotted') sns.scatterplot(x = t, y = n[:, 1], label = 'predator') sns.lineplot(x = t, y = n[:, 1], linestyle = 'dotted') sns.lineplot(x = t_pred, y = n_pred_mean[:, 0], label = 'prey (predicted)') sns.lineplot(x = t_pred, y = n_pred_mean[:, 1], label = 'predator (predicted)') plt.fill_between(x = t_pred, y1 = n_pred_mean[:, 0] - n_pred_std[:, 0], y2 = n_pred_mean[:, 0] + n_pred_std[:, 0], alpha = 0.2) plt.fill_between(x = t_pred, y1 = n_pred_mean[:, 1] - n_pred_std[:, 1], y2 = n_pred_mean[:, 1] + n_pred_std[:, 1], alpha = 0.2)
期待通りの周期的なパターンが出てきました。おそらくうまくいったと思われます。
おわりに
NumPyro、個人的には使い勝手が良いと感じました。NumpyやScipyの使う感覚の延長で微分方程式の確率モデリングができるのはありがたいです。化学反応速度論や感染症のモデリングなど、微分方程式を使った確率モデリング *5 のハードルが下がりそうです。
この記事を書き始める前に公式のNumPyro応用例に載っているコードを走らせてみたところ、乱数シード値依存で収束先が変わったり、事前分布が若干恣意的に見えたり等、改善の余地がある印象を受けました。
そこでこの記事では公式チュートリアルのコードを少し改変し、パラメータの取り方等を変更してみました。しかしそれでもまだMCMCの初期値依存で収束先が変わりそうな雰囲気があります。ベイズ難しい......
また、今回はCPUを使ったためMCMCの待ち時間が数分生じました。せっかくJAXを使っているのでGPUの恩恵にあずかりたいです。近々dockerでGPUを使う方法を調べたいです。
関連記事
前回記事 migamamo-bio.hatenablog.com
参考にさせていただいた記事
*3:所説あります: Introduction to Mathematical Modeling, Whitman College
*4:どの範囲の一様分布かは調べていません
*5:例えば SIRモデル - Wikipedia