Quickstart¶
VIABEL currently supports both standard KL-based variational inference (KLVI) and chi-squared variational inference (CHIVI). Models are provided as Autograd-compatible log densities or can be constructed from PyStan fit objects.
As a simple example, we consider Neal’s funnel distribution in 2 dimensions so that we can visualize the results.
[1]:
import autograd.numpy as np
import autograd.scipy.stats.norm as norm
D = 2 # number of dimensions
log_sigma_stdev = 1. # 1.35
def log_density(x):
    mu, log_sigma = x[:, 0], x[:, 1]
    sigma_density = norm.logpdf(log_sigma, 0, log_sigma_stdev)
    mu_density = norm.logpdf(mu, 0, np.exp(log_sigma))
    return sigma_density + mu_density
Black-box Variational Inference¶
VIABEL’s bbvi
function provides reasonable defaults: the objective is the ELBO (i.e., the exclusive Kullback-Leibler divergence KL(q‖p)), the approximation family is mean-field Gaussian, and optimization uses RMSProp with adaptive step-size reduction and an automated stopping rule.
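To make the default objective concrete, here is a pure-NumPy/SciPy sketch of a Monte Carlo ELBO estimate for a mean-field Gaussian on the funnel. This is illustrative only: bbvi uses autograd and its own gradient estimators internally, and the function names here are hypothetical.

```python
import numpy as np
from scipy.stats import norm

def funnel_log_density(x):
    # the same 2-D funnel as above: columns of x are (mu, log_sigma)
    mu, log_sigma = x[:, 0], x[:, 1]
    return norm.logpdf(log_sigma, 0., 1.) + norm.logpdf(mu, 0., np.exp(log_sigma))

def mc_elbo(means, log_stdevs, n_samples=10000, seed=0):
    # ELBO = E_q[log p(x) - log q(x)], estimated with draws from q
    rng = np.random.default_rng(seed)
    stdevs = np.exp(log_stdevs)
    x = means + stdevs * rng.standard_normal((n_samples, 2))
    log_q = norm.logpdf(x, means, stdevs).sum(axis=1)
    return np.mean(funnel_log_density(x) - log_q)

# ELBO of a standard-normal q; bbvi adjusts the variational parameters to maximize this
elbo = mc_elbo(np.zeros(2), np.zeros(2))
```

Since the funnel density is normalized, the ELBO is at most zero, and maximizing it is equivalent to minimizing KL(q‖p).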
[2]:
from viabel import bbvi, MFStudentT
results = bbvi(D, log_density=log_density)
average loss = 0.5564 | R hat converged|: 61%|██████ | 6122/10000 [00:16<00:10, 375.88it/s]
Convergence reached at iteration 6122
We can then plot contours of the Gaussian approximation (red) and the target funnel distribution (black):
[3]:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('white')
sns.set_context('notebook', font_scale=2, rc={'lines.linewidth': 2})
def plot_approx_and_exact_contours(log_density, approx, var_param,
                                   xlim, ylim, cmap2='Reds'):
    xlist = np.linspace(*xlim, 100)
    ylist = np.linspace(*ylim, 100)
    X, Y = np.meshgrid(xlist, ylist)
    XY = np.concatenate([np.atleast_2d(X.ravel()), np.atleast_2d(Y.ravel())]).T
    zs = np.exp(log_density(XY))
    Z = zs.reshape(X.shape)
    zsapprox = np.exp(approx.log_density(var_param, XY))
    Zapprox = zsapprox.reshape(X.shape)
    plt.contour(X, Y, Z, cmap='Greys', linestyles='solid')
    plt.contour(X, Y, Zapprox, cmap=cmap2, linestyles='solid')
    plt.xlim(xlim)
    plt.ylim(ylim)
    plt.show()
plot_approx_and_exact_contours(log_density, results['objective'].approx, results['opt_param'],
                               xlim=[-2, 2], ylim=[-3.25, 2])
Diagnostics¶
VIABEL also has a suite of diagnostics for variational inference. We can easily run these using the vi_diagnostics
function, although low-level access to the individual diagnostics is also provided.
[4]:
from viabel import vi_diagnostics
import warnings
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    diagnostics = vi_diagnostics(results['opt_param'], objective=results['objective'], n_samples=100000)
Pareto k is estimated to be khat = 0.67
The 2-divergence is estimated to be d2 = 3.9
WARNING: 0.1 < d2 < 4.6 means the approximation is somewhat inaccurate. Use importance sampling to decrease error.
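The khat value can be read against the usual Pareto-smoothed importance sampling thresholds. The helper below is hypothetical (not part of VIABEL's API); the cutoffs are the standard ones from the PSIS literature.

```python
def interpret_khat(khat):
    # standard PSIS cutoffs: k <= 0.5 good, 0.5 < k <= 0.7 usable, k > 0.7 unreliable
    if khat <= 0.5:
        return 'good: importance sampling estimates are reliable'
    elif khat <= 0.7:
        return 'ok: estimates are usable but may converge slowly'
    else:
        return 'bad: the approximation is too far from the target'

verdict = interpret_khat(0.67)  # the value estimated above falls in the "ok" range
```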
Importance Sampling¶
The Pareto-smoothed importance weights provide a fairly accurate approximation.
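The reweighting idea can be sketched on a toy problem with plain NumPy (hypothetical densities, not VIABEL's internals): draw from a proposal, weight each draw by the target-to-proposal density ratio, and normalize the weights so they sum to one.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

# target N(1, 1), proposal N(0, 2): the proposal is broader than the target
x = rng.normal(0., 2., size=50000)
log_w = norm.logpdf(x, 1., 1.) - norm.logpdf(x, 0., 2.)

# self-normalized importance weights (subtract the max for numerical stability)
w = np.exp(log_w - log_w.max())
w /= w.sum()

# the weighted average estimates the target mean (true value: 1)
is_mean = np.sum(w * x)
```

VIABEL's diagnostics additionally Pareto-smooth the largest weights, which stabilizes this estimator when the raw weights are heavy-tailed.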
[5]:
weights = np.exp(diagnostics['smoothed_log_weights'])
weights /= np.sum(weights)  # normalize so the weights form a valid resampling distribution
all_samples = diagnostics['samples']
# resample 1000 draws, with probability proportional to the smoothed weights, for visualization
subset = np.random.choice(all_samples.shape[1], size=1000, p=weights)
samples = all_samples[:, subset]
plt.plot(samples[0], samples[1], '*', alpha=.3)
plot_approx_and_exact_contours(log_density, results['objective'].approx, results['opt_param'],
                               xlim=[-2, 2], ylim=[-3.25, 2])