バイオとインフォと

バイオとインフォで食べていきたい研究員の技術ブログ

Numpyroを使った常微分方程式のパラメータ推定 - Lotka-Volterra方程式

概要

Numpyroで微分方程式のパラメータ推定をやってみました。コードは前回のNumPuroデビュー記事のコードと同じrepositoryにあります。 github.com

前回のNumPyroデビュー記事はこちら。 migamamo-bio.hatenablog.com

調べてみたところ、公式のNumpyro応用例の中で微分方程式を扱っている例を見つけました。これを参考にコードを書いていきます。
num.pyro.ai

お題

今回のお題は、数理生物学界隈で有名なLotka-Volterra equation ("predator-prey model") です。
Lotka-Volterra方程式は、捕食者 (predator) と被食者 (prey) の個体数を表現する以下の2つの常微分方程式です。
ロトカ・ヴォルテラの方程式 - Wikipedia

 \frac{du}{dt}=\alpha u - \beta uv

 \frac{dv}{dt}=\delta uv - \gamma v

 tが時間、 uがpreyの個体数、 vがpredatorの個体数、 \alpha, \beta, \gamma, \deltaが正の実数値をとるパラメータです。Preyの個体数 uは指数関数的に増えますが、predatorとの遭遇確率 uvに応じて減少します。Predatorの個体数 vは指数関数的に減りますが、preyとの遭遇確率 uvに応じて増加します。Lotka-Volterra方程式は解析的に解けないので、数値的に解きながらパラメータ推定を行います。

今回は、Numpyroに含まれるテスト用データLYNXHAREを使います。中身は1845~1935年のCanada lynx*1 (predator) とsnowshoe hare*2 (prey) の個体数 (in thousands) とされているデータ*3です。

コードの説明

ライブラリの読み込み

ライブラリを読み込みます。ScipyではなくJAXからodeintを読み込みます。

import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

import jax
import numpyro
import numpyro.distributions as dist
import arviz as az

from numpyro.examples.datasets import LYNXHARE, load_dataset
from jax.experimental.ode import odeint
データの読み込み
_, fetch = load_dataset(LYNXHARE, shuffle=False)
t, n = fetch()
モデルの記述

モデルを記述します。

def dn_dt(n, t, n_eq, theta):

    u = n[0]
    v = n[1]

    v_eq = n_eq[0] # v_eq = alpha / beta
    u_eq = n_eq[1] # u_eq = gamma / delta

    beta = theta[0]
    delta = theta[1]

    du_dt = (v_eq - v) * beta * u
    dv_dt = (u - u_eq) * delta * v

    return jax.numpy.stack([du_dt, dv_dt])

def model(t, n, N):

    n_eq = numpyro.sample("n_eq", dist.LogNormal(jax.numpy.log(30.), 0.5).expand([2]))
    theta = numpyro.sample('theta', dist.HalfNormal(0.1).expand([2]))

    n0 = numpyro.sample("n0", dist.LogNormal(jax.numpy.log(30.), 0.5).expand([2]))
    n_mean = odeint(dn_dt, n0, t, n_eq, theta, rtol = 1e-6, atol = 1e-5, mxstep = 1000)

    scale = numpyro.sample("scale", dist.HalfCauchy(10.).expand([2]))
    numpyro.sample("n", dist.Gamma(n_mean / scale, 1. / scale), obs = n)

JAXのodeint (Scipyのodeintではない) を使ってdn_dtを数値的に解きながらパラメータをサンプリングします。

公式のNumpyro応用例に載っているコードとパラメータの取り方を変えました。n_eqはLotka-Volterra方程式の平衡点(predatorとpreyの個体数が均衡して変化しない点)です。平衡点 (\bar u, \bar v)は、 \bar u = \frac{\gamma}{\delta},  \bar v = \frac{\alpha}{\beta}です。
個体数関連のパラメータ (数値積分の初期値n0n_eq =  (\bar u, \bar v)) は10~100くらいの範囲のはずなので、事前分布を対数正規分布 (log30, 0.5) とします。微分方程式のパラメータ ( \beta, \delta) の事前分布は半コーシー分布 (0, 0.1) とします。元データの雰囲気からprey (predator) の増加 (減少) 速度は1年あたり1桁倍くらいのように見えるので、 \alpha = \beta \bar v,  \gamma = \delta \bar uの値は1前後くらいのはず、したがって \beta, \deltaは0.01 ~ 0.1くらいではないかという目算です。また、データはガンマ分布にしたがって生成されることとします。

MCMC

今回はNo-U-turn sampler (NUTS) を使います。Chainを4つ用意して並列化します。

iter_warmup = 1000
iter_sample = 1000
n_chain = 4

t_scaled = t.astype(float) - min(t.astype(float))

numpyro.set_host_device_count(n_chain)
nuts_kernel = numpyro.infer.NUTS(model, adapt_step_size = True, init_strategy = numpyro.infer.init_to_sample())

mcmc = numpyro.infer.MCMC(nuts_kernel, num_warmup = iter_warmup, num_samples = iter_sample,
                          num_chains = n_chain, chain_method = 'parallel')

今回は初期値を事前分布からサンプリングします (init_strategy = numpyro.infer.init_to_sample()の部分)。デフォルトの初期値の選び方だとMCMCが収束しませんでした。マニュアルを読むと、デフォルトでは事前分布ではなく一様分布*4から初期値がサンプリングされるそうです。

MCMC warmup

Warmupのみ実行して各パラメータの動きを確認します。
今回は複数のパラメータをarrayとしてまとめて定義しました。各要素は別の色で表示されます。青が[0]の位置の要素、オレンジが[1]の位置の要素です。
backend_kwargs = {'constrained_layout': True} としないと前回記事のように図の見た目が崩れます。

mcmc.warmup(jax.random.PRNGKey(0), t = t_scaled, n = n, N = n.shape[0],  collect_warmup = True)
az.plot_trace(mcmc, backend_kwargs = {'constrained_layout': True})

若干苦しそうですが1,000 iterationsで収束してくれました。

MCMC sampling

事後分布をサンプリングしていきます。

mcmc.run(jax.random.PRNGKey(0), t = t_scaled, n = n, N = n.shape[0])
az.plot_trace(mcmc, backend_kwargs = {'constrained_layout': True})

サンプリングされたパラメータの統計量 (平均, SD等) と、MCMCの収束判定の指標である \hat{R}を確認します。

dict_diagnostics = numpyro.diagnostics.summary(mcmc.get_samples(group_by_chain = True), prob = 0.95, group_by_chain = True)
df_diagnostics = pd.DataFrame.from_dict(dict_diagnostics)

df_diagnostics
n0 n_eq scale theta
mean [65.87, 36.49] [27.74, 43.62] [28.37, 5.25] [0.015, 0.026]
std [8.02, 4.75] [1.26, 3.47] [4.70, 0.84] [0.0038, 0.0086]
median [65.59, 36.32] [27.67, 43.53] [27.87, 5.18] [0.015, 0.024]
2.5% [49.67, 27.80] [25.36, 36.75] [19.90, 3.77] [0.0078, 0.015]
97.5% [80.84, 46.07] [30.26, 50.26] [37.86, 6.97] [0.023, 0.043]
n_eff [1438.68, 2273.09] [2315.89, 2066.26] [2259.36, 2527.52] [1426.17, 904.76]
r_hat [1.0016, 1.0019] [1.0024, 1.0002] [1.0007, 1.0013] [1.0019, 1.0034]

 \hat{R}がほぼ1であることからMCMCが収束したことが確認できました。

予測分布の生成

サンプリングされたパラメータの事後分布を使って、時間 tに対するpreyの個体数 uとpredatorの個体数 vの予測分布を生成します。
解析的には求まらないので、サンプリングされたパラメータ (1,000個×4 chains分) から u, vの分布を作ってサンプリングします。

pred = numpyro.infer.Predictive(model, mcmc.get_samples(group_by_chain = False))

n_bin = 1000
t_pred_scaled = np.linspace(0, max(t_scaled), n_bin)
t_pred = t_pred_scaled + min(t)

n_pred = pred(jax.random.PRNGKey(0), t = t_pred_scaled, n = None, N = t_pred_scaled.shape[0])['n']

 tの値に対して4,000個の u, vがサンプリングされました。その平均とSDを元のデータに重ねてみます。

n_pred_mean = n_pred.mean(0)
n_pred_std = n_pred.std(0)

sns.scatterplot(x = t, y = n[:, 0], label = 'prey')
sns.lineplot(x = t, y = n[:, 0], linestyle = 'dotted')
sns.scatterplot(x = t, y = n[:, 1], label = 'predator')
sns.lineplot(x = t, y = n[:, 1], linestyle = 'dotted')
sns.lineplot(x = t_pred, y = n_pred_mean[:, 0], label = 'prey (predicted)')
sns.lineplot(x = t_pred, y = n_pred_mean[:, 1], label = 'predator (predicted)')
plt.fill_between(x = t_pred,
                 y1 = n_pred_mean[:, 0] - n_pred_std[:, 0],
                 y2 = n_pred_mean[:, 0] + n_pred_std[:, 0],
                 alpha = 0.2)
plt.fill_between(x = t_pred,
                 y1 = n_pred_mean[:, 1] - n_pred_std[:, 1],
                 y2 = n_pred_mean[:, 1] + n_pred_std[:, 1],
                 alpha = 0.2)

期待通りの周期的なパターンが出てきました。おそらくうまくいったと思われます。

おわりに

NumPyro、個人的には使い勝手が良いと感じました。NumpyやScipyの使う感覚の延長で微分方程式の確率モデリングができるのはありがたいです。化学反応速度論や感染症モデリングなど、微分方程式を使った確率モデリング *5 のハードルが下がりそうです。

この記事を書き始める前に公式のNumPyro応用例に載っているコードを走らせてみたところ、乱数シード値依存で収束先が変わったり、事前分布が若干恣意的に見えたり等、改善の余地がある印象を受けました。
そこでこの記事では公式チュートリアルのコードを少し改変し、パラメータの取り方等を変更してみました。しかしそれでもまだMCMCの初期値依存で収束先が変わりそうな雰囲気があります。ベイズ難しい......

また、今回はCPUを使ったためMCMCの待ち時間が数分生じました。せっかくJAXを使っているのでGPUの恩恵にあずかりたいです。近々dockerでGPUを使う方法を調べたいです。

関連記事

前回記事 migamamo-bio.hatenablog.com

参考にさせていただいた記事

num.pyro.ai

toeming.hatenablog.com