バイオとインフォと

バイオとインフォで食べていきたい研究員の技術ブログ

NumPyroでMCMC - ロジスティックモデルを使った個体数増加のモデリング

概要

前から気になっていたNumpyroを試してみたので覚書。コードは以下のrepositoryにあります。
github.com

NumPyro↓ これを使うと手軽に確率モデリングのコードが書けるらしいです。
num.pyro.ai

お題

ロジスティックモデルにしたがう個体数増加のモデリングをやってみました。
ロジスティックモデルは、 Nを個体数、 tを時間として以下の微分方程式で表されるモデルです。
ロジスティック方程式 - Wikipedia

 \frac{dN}{dt}=rN(1-\frac{N}{K})

 rが個体数増加速度、 Kが環境収容力です。個体数 Nは指数関数的に増えますが、環境収容力 Kに近づくにつれ速度が抑えられます。菌や細胞を培養したときの挙動に近いです。

上の微分方程式を解くと以下の式になります。 N_0は個体数の初期値です。

 N = \frac{K}{1+(\frac{K}{N_0}-1)e^{-rt}}

今回のお題用に、平均が r = 1,  K = 10,  N_0 = 0.1のロジスティックモデルにしたがうデータを用意しました。
 0 \lt t \lt 10区間で、 \theta = 0.1のガンマ分布*1にしたがう100個のデータを生成しました。
NumPyroを使ったマルコフ連鎖モンテカルロMCMC)法により、データから、元のパラメータの復元を試みます。

生成されたデータと、ロジスティックモデルにしたがう期待値 ± ガンマ分布のSDのplot。

コードの説明

ライブラリの読み込み
import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

import jax
import numpyro
import numpyro.distributions as dist
import arviz as az
Toy dataの生成

上述の通りtoy dataを生成します。

def logistic(t, k, r, n0):
    n = k / (1. + (k / n0 - 1.) * np.exp(-r * t))
    return n

k = 10.
r = 1.
n0 = 0.1
scale = 0.1

t_max = 10.
n_data = 100

df_data = pd.DataFrame({ 't': np.linspace(0, t_max , n_data) })

np.random.seed(seed = 1)
df_data = df_data.assign(
        n = np.random.gamma(logistic(df_data['t'], k, r, n0) / scale, scale, n_data)
    )
モデルの記述

上述のデータの生成過程を記述します。事前分布はすべてコーシー or 半コーシー分布とします(無情報のつもり)。
データはnumpy arrayではなくjax numpy arrayとして扱う必要があります。ガンマ分布のパラメータの取り方がnumpyとnumpyroで違うので要注意です。

def logistic_jax(t, k, r, n0):
    n = k / (1. + (k / n0 - 1.) * jax.numpy.exp(-r * t))
    return n

def model(t, n, N):

    k = numpyro.sample('k', dist.HalfCauchy(10.))
    r = numpyro.sample('r', dist.Cauchy(10.))
    n0 = numpyro.sample('n0', dist.HalfCauchy(10.))
    scale = numpyro.sample('scale', dist.HalfCauchy(10.))

    n_mean = logistic_jax(t, k, r, n0)

    with numpyro.plate('data', N):
        numpyro.sample('n', dist.Gamma(n_mean / scale, 1. / scale), obs = n)
MCMCでサンプリング

MCMCによるサンプリングで事後分布を推定します。今回はNo-U-turn sampler (NUTS) を使います。Chainを4つ用意して並列化します。

iter_warmup = 200
iter_sample = 1000
n_chain = 4

numpyro.set_host_device_count(n_chain)
nuts_kernel = numpyro.infer.NUTS(model, adapt_step_size = True)

mcmc = numpyro.infer.MCMC(nuts_kernel, num_warmup = iter_warmup, num_samples = iter_sample,
                          num_chains = n_chain, chain_method = 'parallel')
mcmc.run(jax.random.PRNGKey(0), t = df_data['t'].values, n = df_data['n'].values, N = df_data.shape[0])
結果の確認

サンプリングされたパラメータ(事後分布)を確認します。

az.plot_trace(mcmc)

4つのchainが正解の値に収束しました。

サンプリングされたパラメータの統計量 (平均, SD等) と、MCMCの収束判定の指標である \hat{R}を確認します。

dict_diagnostics = numpyro.diagnostics.summary(mcmc.get_samples(group_by_chain = True), prob = 0.95, group_by_chain = True)
df_diagnostics = pd.DataFrame.from_dict(dict_diagnostics)

df_diagnostics
k n0 r scale
mean 10.202814 0.096435 1.015158 0.082979
std 0.183319 0.011746 0.033786 0.012336
median 10.197577 0.095709 1.015514 0.081736
2.5% 9.854436 0.073230 0.949016 0.060573
97.5% 10.567705 0.118404 1.080463 0.108148
n_eff 2063.584710 1622.019435 1583.550530 1993.970682
r_hat 1.002008 1.001410 1.001241 1.000152

サンプリングされたパラメータの値は正解の値にほぼ等しく、また、 \hat{R}がほぼ1であることからMCMCが収束していることが確認できました。

予測分布の生成(おまけ)

最後に、サンプリングされたパラメータの事後分布を使って、時間 tに対する個体数 Nの予測分布を生成します。
解析的には求まらないので、サンプリングされたパラメータ (1,000個×4 chains分) から Nの分布を作ってサンプリングします。

pred = numpyro.infer.Predictive(model, mcmc.get_samples(group_by_chain = False))

n_bin = 100
t_pred = np.linspace(0, t_max, n_bin)

n_pred = pred(jax.random.PRNGKey(0), t = t_pred, n = None, N = t_pred.shape[0])['n']

 tの値に対して4,000個の Nがサンプリングされました。その平均とSDを元のデータに重ねてみます。

n_pred_mean = n_pred.mean(0)
n_pred_std = n_pred.std(0)

sns.scatterplot(x = 't', y = 'n', data = df_data)
sns.lineplot(x = t_lineplot, y = n_pred_mean)
plt.fill_between(x = t_lineplot,
                 y1 = n_pred_mean - n_pred_std,
                 y2 = n_pred_mean + n_pred_std,
                 alpha = 0.2)

きれいに重なりました!

おわりに

昔stanでベイジアンモデリングの真似事をやったことがありますが、当時に比べるとかなり楽に書けた印象です。Pythonで手軽に書けるのはありがたいです。
今回はGPUが使えずJAXの恩恵にあずかれませんでした。今度dockerでGPUを使う方法を調べてみようと思います。

関連記事

【2023-02-05追記】続編 migamamo-bio.hatenablog.com

参考にさせていただいた記事

blog.deepblue-ts.co.jp

zenn.dev

*1:平均が定まっているのでガンマ分布のもう一方のパラメータ kも定まります