Kalman filter style recursion to marginalize state variables to speed up Stan inference

Summary: Technical stuff about speeding up Bayesian inference performed with Stan, for a certain random-walk model that contains a conditionally linear-Gaussian state-space model. The basic idea is that we can write a Kalman-filter-style loop within the Stan code so that the evaluated log-posterior is one in which the state variables are marginalized out. The HMC sampler then needs to move only in the space of the parameters. The complete Stan code is in the accompanying GitHub repository.

Acknowledging prior work

The main idea is not novel -- a long time ago I saw a Kalman filter within Stan somewhere, perhaps this one by Jeffrey Arnold (https://gist.github.com/jrnold/4700387). This possibility was also once discussed, a long time ago, in a group meeting (at least Aki Vehtari and Simo Särkkä were present; I don't recall who else). I also don't recall which of these long times is longer.

Introduction

This post is inspired by a state-space model by James Savage for aggregating presidential election polls, discussed on Andrew Gelman's blog. Here, I consider the following model, which is a simplification but contains the essential parts for demonstrating how the aforementioned poll model could be rewritten to marginalize out the state variables. Sampler performance (say, effective samples per second) should improve (as demonstrated by the experiment in this post). A drawback is that the .stan file loses the property of doubling as an easily readable model specification.

\begin{align*} x_1 & \sim \mathrm{N}(m_0, P_0), \\ x_k & = x_{k-1} + \epsilon^x_k, ~~ \epsilon^x_k \sim \mathrm{N}(0,Q), \\ z_k & = x_k + \epsilon^z_k, ~~ \epsilon^z_k \sim \mathrm{N}(0,\sigma_z^2), \\ y_k & = z_k + \epsilon^y_k, ~~ \epsilon^y_k \sim \mathrm{N}(0,\sigma_{y,k}^2) \end{align*}

for \(k=1,2,\ldots,T\). The noise components (\(\epsilon^x,\epsilon^z,\epsilon^y\)) and the initial state \(x_1\) are mutually independent. The idea is that the phenomenon-of-interest follows a latent trend (\(x\)), which is modeled by a Gaussian random walk. Besides the trend, there is some time-instant-specific jitter in the phenomenon-of-interest (so the phenomenon-of-interest is actually \(z\)). Finally, this \(z\) is observed via noisy measurements \(y\).

So, we observe \(y\) and are interested in inferring the posterior distribution of \(z\) (and of \(x\) and the parameters). \(m_0\) and \(P_0\) are given fixed values reflecting our information about the initial state. The noise standard deviations of the state process (\(\sqrt{Q}\) and \(\sigma_z\)) are given half-Student-t priors (with mode at 0 and 2 d.o.f.). The measurement noise parameters \(\sigma_{y,k}\) are assumed to be known properties of the measurement devices and are thus given as part of the data. See randomwalk_naive.stan for a readable Stan implementation containing the model written 'as is'.
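For concreteness, here is a minimal sketch of what such a naive implementation could look like; the file in the repository is the authoritative version, and in particular the unit scale of the half-Student-t priors below is an assumption made for illustration.

data {
    int<lower=1> T;                  // number of time steps
    vector[T] y;                     // observations
    vector<lower=0>[T] sigma_y;      // known measurement noise standard deviations
    real m0;                         // prior mean of x[1]
    real<lower=0> P0;                // prior variance of x[1]
}
parameters {
    real<lower=0> sigma_z;           // standard deviation of the z-jitter
    real<lower=0> sqrtQ;             // standard deviation of the random-walk increments
    vector[T] x;                     // latent trend
    vector[T] z;                     // latent phenomenon-of-interest
}
model {
    sigma_z ~ student_t(2, 0, 1);    // half-t via the <lower=0> constraint; scale 1 is an assumption
    sqrtQ ~ student_t(2, 0, 1);
    x[1] ~ normal(m0, sqrt(P0));
    x[2:T] ~ normal(x[1:(T-1)], sqrtQ);
    z ~ normal(x, sigma_z);
    y ~ normal(z, sigma_y);
}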

Kalman filter version

This section describes writing, in Stan, a simple Kalman filter that evaluates the marginal log-posterior \(p(\sigma_z,Q \mid \mathrm{data})\). First, note that \(z\) can be marginalized out of our state-space model to obtain

\begin{align*} x_1 & \sim \mathrm{N}(m_0,P_0), \\ x_k & = x_{k-1} + \epsilon^x_k, ~~ \epsilon^x_k \sim \mathrm{N}(0,Q), \\ y_k & = x_k + r_k, ~~ r_k = \epsilon^y_k + \epsilon^z_k \sim \mathrm{N}(0,\sigma_z^2 + \sigma_{y,k}^2). \end{align*}

Following the Kalman filter representation in, e.g., [1, p. 73], and simplifying using the facts that this is a scalar model where the process is a random walk (\(A=1\)) and the measurement is state plus noise (\(H=1\)), one Kalman filter step (prediction and update) is

\begin{align*} m^{-}_k &= m_{k-1}, \\ P^{-}_k &= P_{k-1} + Q, \\ S_k &= P^{-}_k + R_k, \\ K_k &= P^{-}_k/S_k, \\ m_k &= m^{-}_k + K_k\,(y_k - m^{-}_k), \\ P_k &= P^{-}_k - K_k\,S_k\,K_k, \end{align*}

where \(m^{-}_k,P^{-}_k\) are the parameters of the predicted distribution \(p(x_k\mid y_{1:k-1})\), \(m_k,P_k\) are the parameters of the filtering distribution \(p(x_k \mid y_{1:k})\), \(S_k\) is the innovation variance \(\mathrm{Var}[y_k \mid y_{1:k-1}]\), and \(R_k = \sigma_z^2 + \sigma_{y,k}^2\).

And here is the aforementioned filter written as Stan code:

vector[T] m; //filtered mean x_t|y_1:t
vector[T] P; //filtered var x_t|y_1:t
vector[T] R; //measurement variance,  y_t|x_t
vector[T] S; //innovation variance, y_t|y_1:t-1
vector[T] m_pred; //predicted mean x_t|y_1:t-1
vector[T] P_pred; //predicted var x_t|y_1:t-1
real K;  //Kalman gain, depends on t but not stored

//Filtering
R = sigma_y .* sigma_y + sigma_z*sigma_z;

m_pred[1] = m0;
P_pred[1] = P0;

for (t in 1:T) {
    //Prediction step
    if (t>1) {
        m_pred[t] = m[t-1];
        P_pred[t] = P[t-1] + sqrtQ*sqrtQ;
    }

    //Update step
    S[t] = P_pred[t] + R[t];
    K = P_pred[t]/S[t];
    m[t] = m_pred[t] + K*(y[t] - m_pred[t]);   //The measurement is just signal plus noise, so the predicted measurement mean equals m_pred
    P[t] = P_pred[t] - K*S[t]*K;
}

To evaluate the contribution to the log-posterior, we may use the stored predicted means and innovation variances.
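This works because the Kalman filter yields the marginal likelihood of the data through the prediction error decomposition,

\begin{equation*} p(y_{1:T} \mid \sigma_z, Q) = \prod_{k=1}^{T} p(y_k \mid y_{1:k-1}, \sigma_z, Q) = \prod_{k=1}^{T} \mathrm{N}(y_k;\, m^{-}_k, S_k). \end{equation*}

In Stan code, this becomes: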

for (t in 1:T) {
    y[t] ~ normal(m_pred[t], sqrt(S[t])); //Should vectorize this but I didn't get how to do a vectorized sqrt :(
}
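As a side note, in sufficiently recent Stan versions sqrt() applies elementwise to a vector, so (assuming such a version is available) the loop above could be replaced by a single vectorized statement:

y ~ normal(m_pred, sqrt(S));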

Sampling the states

With the work so far, we obtain an optimized Stan sampler for \(p(\sigma_z,\sqrt{Q} \mid \mathrm{data})\). However, the actual variables of interest might be the states \(x\) and \(z\). It is possible to obtain closed-form Gaussian expressions for their marginal distributions conditional on the data and the parameters (the Rauch-Tung-Striebel smoother; see, e.g., [1]). Here, however, I decided to use a method that generates samples of these variables, so that the output is more directly comparable with the output of the 'naive' version. The algorithm is based on a backward recursion for sampling \(x_{1:T}\), followed by sampling \(z_{1:T}\) conditional on everything else. Note that since these samples no longer affect the log-posterior, this can be performed in the generated quantities block.

Backward pass for generating a sample of x

This is probably well known, but I do not recall a reference (it is almost like a Rauch-Tung-Striebel smoother, except that a sample is drawn at every step). We sample from \(x_{1:T}\mid y_{1:T},\mathrm{Params}\) by first sampling from \(x_T \mid y_{1:T},\mathrm{Params}\), which is just \(\mathrm{N}(m_T,P_T)\), then from \(x_{T-1} \mid x_T, y_{1:T},\mathrm{Params}\), then \(x_{T-2} \mid x_{T-1},x_T, y_{1:T},\mathrm{Params}\), and so on.

So, the conditional distributions to sample from are (dependence on \(\mathrm{Params}\) is left implicit from now on)

\begin{equation*} p(x_k \mid x_{k+1:T},y_{1:T}). \end{equation*}

Due to the conditional independence properties of the model, this simplifies to

\begin{equation*} = p(x_k \mid x_{k+1},y_{1:k}) \propto_{x_k} p(x_k, x_{k+1} \mid y_{1:k}) = p(x_k \mid y_{1:k})\,p(x_{k+1}\mid x_k). \end{equation*}

The first term is just \(\mathrm{N}(x_k; m_k, P_k)\) and the second is \(\mathrm{N}(x_{k+1}; x_k, Q)\), so this is just the familiar case of a univariate Gaussian prior combined with a Gaussian measurement noise model [2, p. 46-47].
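Writing out the product explicitly (the standard precision-weighted combination of the two Gaussians, which is exactly what the loop below implements) gives

\begin{equation*} p(x_k \mid x_{k+1}, y_{1:k}) = \mathrm{N}\!\left(x_k;\ \frac{m_k/P_k + x_{k+1}/Q}{1/P_k + 1/Q},\ \left(\frac{1}{P_k} + \frac{1}{Q}\right)^{-1}\right). \end{equation*}

As Stan code, the backward pass is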

x[T] = normal_rng(m[T], sqrt(P[T]));
for (i in 1:T-1) { //Stan's for-loop index only increases, so the decreasing index is constructed manually
    int t;
    real varx;
    real meanx;
    t = T-i;
    varx = 1 / (1/P[t] + 1/(sqrtQ*sqrtQ));
    meanx = (m[t]/P[t] + x[t+1]/(sqrtQ*sqrtQ))*varx;
    x[t] = normal_rng(meanx,sqrt(varx));
}

Sampling z

Finally, we desire a sample from \(p(z_{1:T} \mid y_{1:T}, x_{1:T}, \mathrm{Params})\). By the conditional independence properties of the model, this equals

\begin{equation*} \prod_{k=1}^T p(z_k \mid y_k,x_k, \mathrm{Params}). \end{equation*}

and for each \(k\) we again have the basic Gaussian-Gaussian situation.
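Explicitly, combining the 'prior' \(\mathrm{N}(z_k; x_k, \sigma_z^2)\) with the measurement model \(\mathrm{N}(y_k; z_k, \sigma_{y,k}^2)\) gives

\begin{equation*} p(z_k \mid y_k, x_k) = \mathrm{N}\!\left(z_k;\ \frac{x_k/\sigma_z^2 + y_k/\sigma_{y,k}^2}{1/\sigma_z^2 + 1/\sigma_{y,k}^2},\ \left(\frac{1}{\sigma_z^2} + \frac{1}{\sigma_{y,k}^2}\right)^{-1}\right). \end{equation*}

As Stan code: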

for (t in 1:T) {
    real meanz;
    real varz;
    varz = 1/ (1/(sigma_z*sigma_z) + 1/(sigma_y[t]*sigma_y[t]));
    meanz = varz * (x[t]/(sigma_z*sigma_z) + y[t]/(sigma_y[t]*sigma_y[t]));
    z[t] = normal_rng(meanz,sqrt(varz));
}

Experiment

The complete .stan files are in the GitHub repository.

I simulated a trajectory with 100 observations (see generatedata.py for details) and ran both the 'naive' version and the Kalman-filter version for 100,000 steps (see the repository for full replication instructions). The computations were performed on my cheap laptop (bought two years ago, running Windows), so I can only hope the results are somewhat insensitive to the details of the computing environment (up to a constant factor in runtimes). Please test on a more professional and controlled setup. For example, I have no idea whether some surprise background processes were eating resources from one of the samplers.

Selected parts of the summary of the run with the 'naive' version:

Inference for Stan model: randomwalk_naive_model
1 chains: each with iter=(100000); warmup=(0); thin=(1); 100000 iterations saved.

Warmup took (6.3) seconds, 6.3 seconds total
Sampling took (477) seconds, 8.0 minutes total

                    Mean     MCSE   StdDev        5%    50%   95%  N_Eff  N_Eff/s  R_hat

sigma_z         4.3e-001 5.9e-003 1.5e-001  1.6e-001   0.44  0.67    665 1.4e+000    1.0
sqrtQ           8.4e-002 1.8e-003 4.8e-002  3.2e-002  0.072  0.18    725 1.5e+000    1.0
z[50]           1.1e+000 1.3e-002 4.7e-001  3.4e-001    1.1   1.9   1440 3.0e+000    1.0
x[50]           9.9e-001 3.3e-003 2.2e-001  6.6e-001   0.97   1.4   4366 9.1e+000    1.0

Selected parts of the summary of the run with the 'Kalman' version:

Inference for Stan model: randomwalk_kalman_model
1 chains: each with iter=(100000); warmup=(0); thin=(1); 100000 iterations saved.

Warmup took (2.1) seconds, 2.1 seconds total
Sampling took (523) seconds, 8.7 minutes total

                    Mean     MCSE   StdDev        5%       50%    95%    N_Eff  N_Eff/s    R_hat

sigma_z         4.3e-001 8.0e-004 1.5e-001  1.8e-001  4.3e-001   0.67 3.4e+004 6.5e+001 1.0e+000
sqrtQ           7.5e-002 2.6e-004 5.4e-002  8.4e-003  6.5e-002   0.18 4.3e+004 8.3e+001 1.0e+000
z[50]           1.1e+000 1.5e-003 4.8e-001  3.2e-001  1.1e+000    1.9 9.7e+004 1.8e+002 1.0e+000
x[50]           9.7e-001 8.2e-004 2.2e-001  6.6e-001  9.5e-001    1.4 7.4e+004 1.4e+002 1.0e+000

The numbers are close enough that it is at least very plausible that the two samplers target the same distribution. The effective sample sizes (both per second and relative to the total number of samples) are clearly better with the Kalman version.

Remarks

  • I wonder how one should test that two Stan implementations really have the correct log-posterior, except by the annoying method of running long MCMC chains with both and agonizing over whether small differences in the results are due to bugs or to Monte Carlo error. I initially had a bug here (I accidentally used \(R_k\) instead of \(S_k\) in the likelihood evaluation loop). Testing with smaller data, this caused a clear difference in results, but it still took me some time to convince myself that it was a bug rather than, e.g., one of the samplers just having poor performance.
  • If the samples of \(x\) and \(z\) were not required, the Kalman filter loop could be placed in the model block. I first did this and learned that the generated quantities block then does not have access to the intermediate results. Thus I was forced to put the Kalman filter in the transformed parameters block. I guess this otherwise does not matter, but at least with CmdStan the output file (and the stansummary output) will then contain all the uninteresting intermediate results. Or is there a way to suppress these variables from output even if they are used in the transformed parameters block?

References

  1. Särkkä, S. (2013). Bayesian Filtering and Smoothing. Cambridge University Press.
  2. Gelman, A., Carlin, J. B., Stern, H. S., and Rubin, D. B. (2004). Bayesian Data Analysis (Second edition). Chapman & Hall/CRC.