「大人の教養・知識・気付き」を伸ばすブログ

一流の大人(ビジネスマン、政治家、リーダー…)として知っておきたい、教養・社会動向を意外なところから取り上げ学ぶことで“気付く力”を伸ばすブログです。データ分析・語学に力点を置いています。 →現在、コンサルタントの雛になるべく、少しずつ勉強中です(※2024年1月21日改訂)。

MENU

Juliaを使ってみる(14/22):ラプラス近似(その1)

今回やりたいこと

 勾配を用いた効率的な事後分布のシミュレーションを行う。

を参照する。

1. 勾配を用いる必要性

 対数事後分布の勾配を計算するのが今後のアプローチのコアである。

1.1 なぜ勾配を用いるのか

 伝承サンプリングや数値積分を用いた統計モデルの事後分布計算では、いくつかのデメリットを有する。

  • (数値積分において)推定対象パラメータの数が増えるとその推定のための計算量が膨大になり、実用上、計算不可能になり得る。
  • (数値積分において)計算に用いる各点に対して開始時点と終了時点を与える必要があり、それらが適切に設定できる保証もない。
  • (伝承サンプリングにおいて)観測データに合致した標本が得づらくなる。


以上から、数値積分や伝承サンプリングでは限界がある。

1.2 勾配を用いた計算の効率化

 歴史上、\mathrm{Bayes}統計が実戦で応用されるまでに計算機やアルゴリズムの発展を待つ必要があったのはこのような計算効率の問題があったためであった。これに対して、近年では\mathrm{Markov}連鎖\mathrm{Monte\ Carlo}法や\mathrm{Laplace}近似、逐次\mathrm{Monte\ Carlo}法や変分推論などの計算手法が登場し、現実的な推論計算手段を提供するようになってきた。これらは「無駄な計算を省き、重要な計算のみ重点的に行う」というアイディアの下で開発されてきた。
 

 目的とな事後分布について勾配を計算すると、勾配は各時点で事後分布の密度が高い場所の向きを表し、標本が発生しやすい方向を選択することができる。
 事後分布の勾配を利用することでサンプリングを効率化できる。勾配が分かっていれば、「勾配の矢印が向いている先の領域においてより多くの標本を発生させ、逆に向いていない領域からはあまり標本を発生させない」ことによってそうした勾配=分布に従う標本を生成できそうである。ただしただただ密度の高い部分からサンプリングしたのでは真の事後分布からサンプリングしたとは言い難い。\mathrm{Hamilton} \mathrm{Monte} \mathrm{Carlo}法はそのバランスを上手く取るようにしている。

 もう1つの方法として、事後分布をより扱いやすい分布を用いて近似・簡略化する方法がある。近似手法はサンプリングよりも高速であるものの、正規分布などの(極度に)簡単な分布に簡略化するために、事後分布の近似能力は低く(=粗く)なることが多い。

1.2.1 線形回帰

 まずは線形回帰y=w_1 x+w_2の推定を考える。ただし簡単のためにw_2=0とする。
 w_1の事後分布を考える。



\begin{aligned}
P\left\{w_1|\boldsymbol{Y},\boldsymbol{X},w_2\right\}&=\displaystyle{\frac{P\left\{\boldsymbol{Y},w_1|\boldsymbol{X},w_2\right\}}{P\left\{\boldsymbol{Y}|\boldsymbol{X},w_2\right\}}}\\
&=\displaystyle{\frac{P\left\{\boldsymbol{Y},w_1|\boldsymbol{X},w_2\right\}}{\displaystyle{\int p(\boldsymbol{Y},w_1|\boldsymbol{X},w_2)}dw_1}}
\end{aligned}


w_1の事後分布が与えられる。
 問題は、モデルが複雑になった時の分母部分の積分計算である。
\mathrm{Laplace}近似では、分母の周辺尤度を直接扱わずに計算できる。周辺尤度はw_1に依存するのみの定数と見なすことができるから、事後分布の確率密度関数w_1の関数であることを考慮すれば、次のように事後分布は分子のみに比例すると考えることができる。すなわち



\begin{aligned}
P\{w_1|\boldsymbol{X},\boldsymbol{Y},w_2\}\propto P\{\boldsymbol{Y},w_1|\boldsymbol{X},w_2\}
\end{aligned}


が成り立つ。この式の右辺はモデルの同時分布そのものであり、対数を取ることで



\begin{aligned}
\log P\{w_1|\boldsymbol{X},\boldsymbol{Y},w_2\}&=\log P\{\boldsymbol{Y}|\boldsymbol{X},w_1,w_2\}+\log P\{w_1\}+C\\
&=\displaystyle{\sum_{n=1}^{N} \log P\{y_n|x_n,w_1,w_2\}}+\log P\{w_1\}+C
\end{aligned}


と書き換えることができ、これは非正規化対数事後分布と呼ばれる。

using Distributions, PyPlot, ForwardDiff, LinearAlgebra

# n次元単位行列
eye(n) = Diagonal{Forward64}(I, n)

# パラメータ抽出用の関数
unzip(a) = map(x -> getfield(a, x), fieldnames(eltype(a)))

# グラフの諸設定を行う関数
function set_options(ax, xlabel, ylabel, title;
                      grid = true, gridy = false, legend = false)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if grid
        if gridy
            ax.grid(axis = "y")
        else
            ax.grid()
        end
    end
    legend && ax.legend()
end


### Laplace近似
# (1) 勾配法によって事後分布の極大値を1つ求め、近似用の正規分布の平均とする
# (2) 求めた極大値において2回微分が一致するように近似用の正規分布の分散を与える

X_obs = [-2,1,5]
Y_obs = [-2.2,-1.0,1.5]

# まずは簡単にy=w₁xを推定する
σ = 1.0 # 誤差項の標準偏差

# 事前分布の平均値と標準偏差
μ₁ = 0.0
σ₁ = 10.0

w₂ = 0.0

ulp(w₁) = sum(logpdf.(Normal.(w₁*X_obs .+ w₂,σ), Y_obs)) + logpdf(Normal(μ₁, σ₁), w₁)

w₁s = range(-5, 5, length = 100)

fig,axes = subplots(1, 2, figsize = (8,4))

# 
axes[1].plot(w₁s, ulp.(w₁s))
set_options(axes[1], "w₁","log density (unnormalised)", "unnormalised log posterior distribution")


axes[2].plot(w₁s, exp.(ulp.(w₁s)))
set_options(axes[2], "w₁","log density (unnormalised)", "unnormalised log posterior distribution")

tight_layout()


 これに対して、\mathrm{Laplace}近似では近似用の正規分布を用いる。

  • 勾配法により事後分布の極大値を1つ求め、近似用の正規分布の平均とする。
function gradient_method_1dim(f, x_init, η, maxiter)
    f′(x) = ForwardDiff.derivative(f, x)
    x_seq = Array{typeof(x_init), 1}(undef, maxiter)
    x_seq[1] = x_init
    
    for i in 2:maxiter
        x_seq[i] = x_seq[i-1] + η*f′(x_seq[i-1])
    end
    
    x_seq
end


# 最適化パラメータ
w₁_init = 0.0
maxiter = 100
η = 0.01


# 最適化の実施
w₁_seq = gradient_method_1dim(ulp, w₁_init, η, maxiter)


# 勾配法の過程を可視化
fig, ax = subplots(figsize = (8,4))
ax.plot(w₁_seq)
set_options(ax, "iteration","w₁", "w₁ sequence")

プライバシーポリシー お問い合わせ