Using Stan Models: A Robust Linear Regression Example

We will approximate the posterior for the simple 2D robust linear regression model

\[\beta_i \sim \mathcal{N}(0, 10)\]
\[y_n | x_n, \beta, \sigma \sim \mathcal{T}_{40}(\beta^\top x_n, 1)\]

and use Stan to compute (the gradient of) the model log density.
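For readers who want the target written out, here is a minimal NumPy sketch of the unnormalized log posterior that the Stan program is assumed to encode. It is not the contents of robust_regression.stan, which is not shown here. The function names are hypothetical, and the sketch assumes the 10 in \(\mathcal{N}(0, 10)\) is a standard deviation, as in Stan's normal(0, 10).

```python
import math

import numpy as np


def student_t_logpdf(r, df):
    """Log density of a standard (location 0, scale 1) Student-t with df dof."""
    log_const = (math.lgamma((df + 1) / 2) - math.lgamma(df / 2)
                 - 0.5 * math.log(df * math.pi))
    return log_const - 0.5 * (df + 1) * np.log1p(np.square(r) / df)


def log_posterior(beta, x, y, df=40.0, prior_sd=10.0):
    """Unnormalized log posterior: Gaussian prior plus Student-t likelihood.

    Assumes prior_sd is a standard deviation (Stan's normal(0, 10) convention).
    """
    beta = np.asarray(beta, dtype=float)
    # Independent N(0, prior_sd^2) prior on each coefficient
    log_prior = np.sum(-0.5 * (beta / prior_sd) ** 2
                       - math.log(prior_sd * math.sqrt(2 * math.pi)))
    # t_df(beta^T x_n, 1) likelihood, evaluated on the residuals
    log_lik = np.sum(student_t_logpdf(y - x.dot(beta), df))
    return log_prior + log_lik


# Usage on synthetic data shaped like the example below
rng = np.random.default_rng(0)
x_demo = rng.standard_normal((25, 2))
y_demo = x_demo.dot(np.array([-2.0, 1.0])) + rng.standard_t(40, 25)
print(log_posterior([-2.0, 1.0], x_demo, y_demo))
```

Stan additionally supplies the gradient of this quantity by automatic differentiation; the sketch only evaluates it.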

For more details and discussion of this example, see:

Practical posterior error bounds from variational objectives. Jonathan H. Huggins, Mikołaj Kasprzak, Trevor Campbell, Tamara Broderick. In Proc. of the 23rd International Conference on Artificial Intelligence and Statistics (AISTATS), Palermo, Italy. PMLR: Volume 108, 2020.

import os
import pickle
import warnings
import matplotlib.pyplot as plt
import seaborn as sns
import autograd.numpy as np
import pystan

from viabel import bbvi, vi_diagnostics, MultivariateT
# code for comparing to ground-truth posterior

sns.set_context('notebook', font_scale=1.5, rc={'lines.linewidth': 2})

def plot_approx_and_exact_contours(log_density, approx, var_param, cmap2='Reds'):
    xlim = [-4,-1]
    ylim = [-.5,3.5]
    xlist = np.linspace(*xlim, 100)
    ylist = np.linspace(*ylim, 100)
    X, Y = np.meshgrid(xlist, ylist)
    XY = np.concatenate([np.atleast_2d(X.ravel()), np.atleast_2d(Y.ravel())]).T
    zs = np.exp(log_density(XY))
    Z = zs.reshape(X.shape)
    zsapprox = np.exp(approx.log_density(var_param, XY))
    Zapprox = zsapprox.reshape(X.shape)
    cs_post = plt.contour(X, Y, Z, cmap='Greys', linestyles='solid')
    cs_approx = plt.contour(X, Y, Zapprox, cmap=cmap2, linestyles='solid')

def check_accuracy(true_mean, true_cov, var_param, approx):
    approx_mean, approx_cov = approx.mean_and_cov(var_param)
    true_std = np.sqrt(np.diag(true_cov))
    approx_std = np.sqrt(np.diag(approx_cov))
    mean_error=np.linalg.norm(true_mean - approx_mean)
    cov_error_2=np.linalg.norm(true_cov - approx_cov, ord=2)
    cov_norm_2=np.linalg.norm(true_cov, ord=2)
    std_error=np.linalg.norm(true_std - approx_std)

    print('mean error             = {:.3g}'.format(mean_error))
    print('stdev error            = {:.3g}'.format(std_error))
    print('||cov error||_2^{{1/2}}  = {:.3g}'.format(np.sqrt(cov_error_2)))
    print('||true cov||_2^{{1/2}}   = {:.3g}'.format(np.sqrt(cov_norm_2)))

First, compile the robust regression Stan model, caching the compiled model with pickle so that later runs can skip compilation:

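compiled_model_file = 'robust_reg_model.pkl'
try:
    # Reuse a previously compiled model cached on disk
    with open(compiled_model_file, 'rb') as f:
        regression_model = pickle.load(f)
except FileNotFoundError:
    # Otherwise compile the Stan program (slow) and cache it for later runs
    regression_model = pystan.StanModel(file='robust_regression.stan')
    with open(compiled_model_file, 'wb') as f:
        pickle.dump(regression_model, f)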
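The compile-once, load-from-pickle pattern above works for any expensive object. Below is a minimal sketch with a hypothetical helper name, load_or_build; for this example the builder would be something like lambda: pystan.StanModel(file='robust_regression.stan'). Like the code above, it trusts whatever is in the cache file, so the pickle should be deleted after the .stan file is edited.

```python
import os
import pickle
import tempfile


def load_or_build(cache_path, build):
    """Return the object pickled at cache_path, or build, cache, and return it.

    build is any zero-argument callable, e.g. (hypothetically)
    lambda: pystan.StanModel(file='robust_regression.stan').
    """
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            return pickle.load(f)
    obj = build()
    with open(cache_path, 'wb') as f:
        pickle.dump(obj, f)
    return obj


# Usage with a stand-in "expensive" builder that records how often it runs
calls = []


def expensive_build():
    calls.append(1)
    return {'compiled': True}


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'model.pkl')
    first = load_or_build(path, expensive_build)   # builds and writes the cache
    second = load_or_build(path, expensive_build)  # loads from the cache
print(len(calls))  # 1
```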

Next, to use as data, generate 25 observations from the model with \(\beta = (-2, 1)\):

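beta_gen = np.array([-2, 1])
N = 25
# Correlated covariates: standard normal rows mixed by a fixed symmetric matrix
x = np.random.randn(N, 2).dot(np.array([[1,.75],[.75, 1]]))
# Responses follow the model: linear predictor plus Student-t noise with 40 dof
y_raw = x.dot(beta_gen) + np.random.standard_t(40, N)
# Center the responses
y = y_raw - np.mean(y_raw)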
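A little linear algebra shows why this data set is hard for mean-field methods. Each row of x is a standard normal vector multiplied by \(A = \begin{pmatrix}1 & 0.75\\ 0.75 & 1\end{pmatrix}\), so the covariates have covariance \(A^\top A\) and correlation 0.96. Under a rough Gaussian approximation (unit-scale noise, t tails ignored, and the expected rather than the sampled \(X^\top X\)), the posterior precision is about \(N A^\top A\) plus the prior precision, which implies a posterior correlation of about \(-0.96\) between \(\beta_1\) and \(\beta_2\). This is a back-of-the-envelope sketch, not a property of any particular random draw; the variable names are illustrative.

```python
import numpy as np

mix = np.array([[1.0, 0.75], [0.75, 1.0]])
# Rows of x are z @ mix with z ~ N(0, I), so Cov(x_n) = mix^T mix
cov_x = mix.T @ mix
corr_x = cov_x[0, 1] / np.sqrt(cov_x[0, 0] * cov_x[1, 1])

# Rough Gaussian approximation: likelihood precision ~ n_obs * Cov(x_n) (unit-scale
# noise, t tails ignored) plus the N(0, 10^2) prior precision on each coefficient
n_obs = 25
post_precision = n_obs * cov_x + np.eye(2) / 10.0 ** 2
post_cov = np.linalg.inv(post_precision)
corr_beta = post_cov[0, 1] / np.sqrt(post_cov[0, 0] * post_cov[1, 1])
print(round(corr_x, 3), round(corr_beta, 3))  # 0.96 -0.96
```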

For illustration purposes, generate ground-truth posterior samples using Stan’s dynamic HMC implementation:

data = dict(N=N, x=x, y=y, df=40)
fit = regression_model.sampling(data=data, iter=50000, thin=50, chains=4)
true_mean = np.mean(fit['beta'], axis=0)
true_cov = np.cov(fit['beta'].T)
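The number of draws behind true_mean and true_cov follows from the sampling arguments. The arithmetic below assumes PyStan 2's default of discarding the first iter // 2 iterations of each chain as warmup. Because fit['beta'] stacks draws as rows with one column per coefficient, it is transposed before np.cov, which by default treats rows as variables; passing rowvar=False is equivalent, as the stand-in draws below illustrate.

```python
import numpy as np

iters, thin, chains = 50000, 50, 4
warmup = iters // 2  # assumed PyStan 2 default
draws_per_chain = (iters - warmup) // thin
total_draws = chains * draws_per_chain
print(draws_per_chain, total_draws)  # 500 2000

# Stand-in for fit['beta']: one row per draw, one column per coefficient
beta_draws = np.random.default_rng(0).standard_normal((total_draws, 2))
cov_transposed = np.cov(beta_draws.T)
cov_rowvar = np.cov(beta_draws, rowvar=False)
print(cov_transposed.shape)  # (2, 2)
```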

Standard mean-field variational inference

As a first example, we compute a mean-field variational approximation using standard variational inference, that is, by maximizing the evidence lower bound (ELBO):

mf_results = bbvi(2, fit=fit, num_mc_samples=50)
average loss = 21.352 | learning rate = 0.025 | (-0.082572, 0.093071):  30%|███       | 3000/10000 [00:09<00:22, 309.39it/s]
Stopping rule reached at iteration 3000
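The ELBO is \(\mathbb{E}_q[\log p(\beta, y)] - \mathbb{E}_q[\log q(\beta)]\), and in practice it is estimated with Monte Carlo draws from \(q\) (hence the num_mc_samples argument). Below is a minimal NumPy sketch of such an estimate for a mean-field Gaussian \(q\), using the reparameterization \(\beta = \mu + s \odot \epsilon\) with \(\epsilon \sim \mathcal{N}(0, I)\) and the closed-form Gaussian entropy. It illustrates the objective only; the helper names are hypothetical and this is not viabel's implementation, which must also differentiate the estimate and run a stochastic optimizer.

```python
import numpy as np


def elbo_estimate(log_p, mean, log_sd, num_mc_samples=50, rng=None):
    """Monte Carlo ELBO for q = N(mean, diag(exp(log_sd))^2).

    log_p maps an (S, D) array of points to S (possibly unnormalized) log densities.
    """
    rng = np.random.default_rng() if rng is None else rng
    mean = np.asarray(mean, dtype=float)
    log_sd = np.asarray(log_sd, dtype=float)
    # Reparameterized draws: beta = mean + sd * eps, eps ~ N(0, I)
    eps = rng.standard_normal((num_mc_samples, mean.size))
    draws = mean + np.exp(log_sd) * eps
    # Entropy of a diagonal Gaussian, available in closed form
    entropy = np.sum(log_sd) + 0.5 * mean.size * (1.0 + np.log(2.0 * np.pi))
    return np.mean(log_p(draws)) + entropy


def std_normal_logpdf(points):
    """Normalized log density of N(0, I), evaluated row-wise."""
    d = points.shape[1]
    return -0.5 * np.sum(points ** 2, axis=1) - 0.5 * d * np.log(2.0 * np.pi)


# When q equals a normalized target, the ELBO (= log Z - KL) is 0 up to Monte Carlo error
print(elbo_estimate(std_normal_logpdf, np.zeros(2), np.zeros(2),
                    num_mc_samples=100000, rng=np.random.default_rng(0)))
```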

We can check approximation quality using vi_diagnostics, which indicates that the approximation is poor:

mf_objective = mf_results['objective']
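with warnings.catch_warnings():
    # catch_warnings() only saves and restores filters; silence warnings inside the block
    warnings.simplefilter('ignore')
    diagnostics = vi_diagnostics(mf_results['var_param'], objective=mf_objective)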
Pareto k is estimated to be khat = 0.91
WARNING: khat > 0.7 means importance sampling is not feasible.
WARNING: not running further diagnostics
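The \(\hat{k}\) diagnostic is computed from importance weights \(w = p(\beta, y)/q(\beta)\) at draws \(\beta \sim q\): when \(q\) is too narrow, the few draws that land in the tails of \(q\) receive huge weights. Estimating \(\hat{k}\) itself requires fitting a generalized Pareto distribution to the largest weights, which is beyond a short sketch, so the hypothetical helper below computes a simpler and cruder summary of the same weights, the normalized Kish effective sample size. It is not what vi_diagnostics reports.

```python
import numpy as np


def normalized_ess(log_p, log_q):
    """Kish effective sample size of importance weights, divided by the number of draws.

    log_p: (possibly unnormalized) log target at draws from q.
    log_q: log density of q at the same draws.
    Returns a value in (0, 1]; values near 0 signal a few dominant weights.
    """
    log_w = log_p - log_q
    # Subtracting the max avoids overflow and cancels in the ratio below
    w = np.exp(log_w - np.max(log_w))
    return np.sum(w) ** 2 / (w.size * np.sum(w ** 2))


def normal_logpdf(x, sd):
    return -0.5 * (x / sd) ** 2 - np.log(sd * np.sqrt(2.0 * np.pi))


# Target N(0, 1) with an approximation that is too narrow, N(0, 0.3^2)
rng = np.random.default_rng(0)
draws = rng.normal(0.0, 0.3, size=10000)
print(normalized_ess(normal_logpdf(draws, 1.0), normal_logpdf(draws, 0.3)))
```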

Indeed, due to the strong posterior correlation, the variational approximation dramatically underestimates uncertainty:

plot_approx_and_exact_contours(mf_objective.model, mf_objective.approx, mf_results['var_param'])
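The underestimation has a simple closed form when the target is Gaussian, which this posterior roughly is. The Gaussian mean-field approximation that minimizes \(\mathrm{KL}(q \,\|\, p)\) gives coordinate \(i\) the variance \(1/\Lambda_{ii}\), where \(\Lambda\) is the target's precision matrix, rather than the true marginal variance \(\Sigma_{ii}\). For two coordinates with unit marginal variances and correlation \(\rho\), this shrinks each standard deviation by a factor \(\sqrt{1 - \rho^2}\), about 0.28 for \(\rho = -0.96\). A minimal sketch with a hypothetical helper name:

```python
import numpy as np


def mean_field_sds(cov):
    """Standard deviations of the KL(q || p)-optimal mean-field Gaussian
    for a Gaussian target with covariance cov."""
    precision = np.linalg.inv(cov)
    # Each optimal factor has variance 1 / Lambda_ii, not the marginal Sigma_ii
    return 1.0 / np.sqrt(np.diag(precision))


rho = -0.96
cov = np.array([[1.0, rho], [rho, 1.0]])
print(np.sqrt(np.diag(cov)))  # true marginal sds: [1. 1.]
print(mean_field_sds(cov))    # approximately [0.28 0.28]
```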

We can confirm the poor approximation quality numerically by examining the mean, standard deviation, and covariance errors as compared to the ground-truth estimates:

check_accuracy(true_mean, true_cov, mf_results['var_param'], mf_objective.approx)
mean error             = 0.0142
stdev error            = 0.718
||cov error||_2^{1/2}  = 0.905
||true cov||_2^{1/2}   = 0.916
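check_accuracy reports square roots of spectral norms. np.linalg.norm(M, ord=2) is the largest singular value of M, which for a symmetric matrix equals its largest absolute eigenvalue, and the square root puts it on the scale of a standard deviation. A covariance error of 0.905 is therefore almost as large as the posterior's own scale of 0.916. The check below uses made-up 2-by-2 matrices, not the fitted covariances, to confirm the identity.

```python
import numpy as np

# Illustrative matrices only (not the fitted covariances)
demo_true_cov = np.array([[0.60, -0.55], [-0.55, 0.60]])
demo_approx_cov = np.diag([0.05, 0.05])
err = demo_true_cov - demo_approx_cov

spectral = np.linalg.norm(err, ord=2)
largest_abs_eig = np.max(np.abs(np.linalg.eigvalsh(err)))
print(spectral, largest_abs_eig)  # both 1.1, up to rounding
print(np.sqrt(spectral), np.sqrt(np.linalg.norm(demo_true_cov, ord=2)))
```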

An approximation with full covariance

To get a good approximation, we can instead use a multivariate t variational family with a full-rank scaling matrix:

t_results = bbvi(2, n_iters=2500, fit=fit, approx=MultivariateT(2, 100), num_mc_samples=100)
Approximation does not support KL. Using base stochastic optimization algorithm instead.
average loss = 23.05: 100%|██████████| 2500/2500 [00:13<00:00, 182.72it/s]
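One standard way to draw from a multivariate t is to divide a full-covariance Gaussian draw by \(\sqrt{u/\nu}\) with \(u \sim \chi^2_\nu\). The sketch below (hypothetical helper, plain NumPy, not viabel's code) reads the second argument of MultivariateT(2, 100) as \(\nu = 100\) degrees of freedom; that reading is an assumption. With scale matrix \(LL^\top\), the covariance is \(\frac{\nu}{\nu - 2} LL^\top\) for \(\nu > 2\), so a single full-rank \(L\) can represent the strong correlation between \(\beta_1\) and \(\beta_2\) that a diagonal family cannot.

```python
import numpy as np


def sample_multivariate_t(mean, scale_tril, df, size, rng):
    """Draw size samples from a multivariate t with location mean,
    scale matrix scale_tril @ scale_tril.T, and df degrees of freedom."""
    mean = np.asarray(mean, dtype=float)
    z = rng.standard_normal((size, mean.size))
    u = rng.chisquare(df, size)
    # Correlated Gaussian draw, then a shared per-draw scaling that fattens the tails
    return mean + (z @ scale_tril.T) / np.sqrt(u / df)[:, None]


rng = np.random.default_rng(0)
scale = np.array([[1.0, -0.9], [-0.9, 1.0]])
L = np.linalg.cholesky(scale)
samples = sample_multivariate_t(np.zeros(2), L, df=100, size=200000, rng=rng)
print(np.cov(samples, rowvar=False))  # close to (100 / 98) * scale
```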

The diagnostics suggest the approximation is accurate:

t_objective = t_results['objective']
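with warnings.catch_warnings():
    # catch_warnings() only saves and restores filters; silence warnings inside the block
    warnings.simplefilter('ignore')
    diagnostics = vi_diagnostics(t_results['var_param'], objective=t_objective)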
Pareto k is estimated to be khat = -0.20

The 2-divergence is estimated to be d2 = 0.011

All diagnostics pass.
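Here the 2-divergence is understood as the Rényi divergence of order 2, \(D_2(p \,\|\, q) = \log \mathbb{E}_q[(p/q)^2]\). One common way to estimate it from draws of \(q\) when \(p\) is unnormalized is the self-normalized ratio \(\log \overline{w^2} - 2 \log \overline{w}\) of importance-weight moments, in which the unknown normalizing constant cancels. The sketch below (hypothetical helper name) implements that estimator and checks it against the closed form for two Gaussians; it is not necessarily the estimator or scale that vi_diagnostics uses for d2.

```python
import numpy as np


def renyi2_estimate(log_p, log_q):
    """Self-normalized estimate of D_2(p || q) = log E_q[(p/q)^2].

    log_p may be unnormalized; its constant cancels between the two moments.
    """
    log_w = log_p - log_q
    w = np.exp(log_w - np.max(log_w))  # this rescaling also cancels
    return np.log(np.mean(w ** 2)) - 2.0 * np.log(np.mean(w))


# Closed form for p = N(0, 1), q = N(0, s^2) is log(s^2 / sqrt(2 s^2 - 1)), valid when s^2 > 1/2
s = 2.0
rng = np.random.default_rng(0)
draws = rng.normal(0.0, s, size=400000)
log_p = -0.5 * draws ** 2 + 3.0  # unnormalized on purpose
log_q = -0.5 * (draws / s) ** 2 - np.log(s * np.sqrt(2.0 * np.pi))
exact = np.log(s ** 2 / np.sqrt(2.0 * s ** 2 - 1.0))
print(renyi2_estimate(log_p, log_q), exact)  # both close to 0.41
```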

Visual inspection and numerical checks confirm the diagnostics:

plot_approx_and_exact_contours(t_objective.model, t_objective.approx, t_results['var_param'])
check_accuracy(true_mean, true_cov, t_results['var_param'], t_objective.approx)
mean error             = 0.00852
stdev error            = 0.0327
||cov error||_2^{1/2}  = 0.242
||true cov||_2^{1/2}   = 0.916