from viabel._psis import psislw
from viabel.approximations import MFGaussian
from viabel.diagnostics import all_diagnostics
from viabel.models import Model, StanModel
from viabel.objectives import ExclusiveKL
from viabel.optimization import RAABBVI, FASO, RMSProp
all = [
'bbvi',
'vi_diagnostics',
]
[docs]def bbvi(dimension, *, n_iters=10000, num_mc_samples=10, log_density=None,
approx=None, objective=None, fit=None, adaptive=True, fixed_lr=False,
init_var_param=None, learning_rate=0.01,
RMS_kwargs=dict(), FASO_kwargs=dict(), RAABBVI_kwargs=dict()):
"""Fit a model using black-box variational inference.
Currently the objective is optimized using ``viabel.optimization.FASO``.
Parameters
----------
dimension : `int`
Dimension of the model parameter.
n_iters : `int`, optional
Number of iterations of the optimization.
num_mc_samples : `int`, optional
Number of Monte Carlo samples to use for estimating the gradient of
the objective.
log_density : `function`, optional
(Unnormalized) log density of the model. Must support automatic
differentiation with ``autograd``. Either ``log_density`` or ``fit``
must be provided.
approx : `ApproximationFamily` object, optional
The approximation family. The default is to use ``viabel.approximations.MFGaussian``.
objective : `VariationalObjective` class
The default is to use ``viabel.objectives.ExclusiveKL``.
fit : `StanFit4model` object, optional
If provided, a ``StanModel`` will be used. Both ``fit`` and
``log_density`` cannot be given.
init_var_param, optional
Initial variational parameter.
adaptive : `bool`, optional
If ``True``, use ``FASO`` with ``RMSProp``. Otherwise use ``RMSProp``.
fixed_lr : `bool`, optional
If ``True``, use ``FASO`` with ``RMSProp`` or ``RMSProp``. Otherwise use ``RAABBVI``.
learning_rate : `float`
Tuning parameter that determines the step size.
RMS_kwargs : `dict`, optional
Dictionary of keyword arguments to pass to ``RMSProp``.
FASO_kwargs : `dict`, optional
Dictionary of keyword arguments to pass to ``FASO``.
RAABBVI_kwargs : `dict`, optional
Dictionary of keyword arguments to pass to ``RAABBVI``.
Returns
-------
results : `dict`
Contains the following entries: `objective` and results from optimizer
"""
if objective is not None:
if fit is not None or log_density is not None or approx is not None:
raise ValueError(
'if objective is specified, cannot specify fit, log_density, or approx')
approx = objective.approx
model = objective.model
else:
if log_density is None:
if fit is None:
raise ValueError(
'either log_density or fit must be specified if objective not given')
model = StanModel(fit)
elif fit is None:
model = Model(log_density)
else:
raise ValueError('log_density and fit cannot both be specified')
if approx is None:
approx = MFGaussian(dimension)
objective = ExclusiveKL(approx, model, num_mc_samples)
if init_var_param is None:
init_var_param = approx.init_param()
base_opt = RMSProp(learning_rate, diagnostics=True, **RMS_kwargs)
if adaptive and not fixed_lr:
opt = RAABBVI(base_opt, **RAABBVI_kwargs)
elif adaptive and fixed_lr:
opt = FASO(base_opt, **FASO_kwargs)
elif not adaptive and fixed_lr:
opt = base_opt
else:
raise ValueError('if fixed_lr is False, adaptive must be True')
opt_results = opt.optimize(n_iters, objective, init_var_param)
opt_results['objective'] = objective
return opt_results
[docs]def vi_diagnostics(var_param, *, objective=None, model=None, approx=None, n_samples=100000):
"""Check variational inference diagnostics.
Check Pareto k and 2-divergence diagnostics. Return additional diagnostics
with mean, standard deviation, and covariance error bounds.
Parameters
----------
var_param : `numpy.ndarray`, shape (var_param_dim,)
The variational parameter.
objective : `function`
model : `Model` object
approx : `ApproximationFamily` object
n_samples : `int`
The number of samples to use for the diagnostics.
Returns
-------
diagnostics : `dict`
Also includes samples and smoothed log weights.
See Also
--------
diagostics.all_diagnostics : Compute all diagnostics.
"""
if objective is None:
if model is None or approx is None:
raise ValueError('either objective or both model and approx must be specified')
elif model is not None or approx is not None:
raise ValueError('model and/or approx cannot be specified if objective is')
else:
model = objective.model
approx = objective.approx
if n_samples <= 0:
raise ValueError('n_samples must be positive')
return _vi_diagnostics(var_param, model, approx, n_samples)
def _vi_diagnostics(var_param, model, approx, n_samples):
# first check Pareto k-hat
samples, smoothed_log_weights, khat = psis_correction(var_param, model, approx, n_samples)
results = dict(samples=samples,
smoothed_log_weights=smoothed_log_weights,
khat=khat)
print('Pareto k is estimated to be khat = {:.2f}'.format(results['khat']))
if results['khat'] > 0.7:
print('WARNING: khat > 0.7 means importance sampling is not feasible.')
print('WARNING: not running further diagnostics')
return results
print()
# if k-hat looks good, check other diagnostics
if approx.supports_pth_moment(2) and approx.supports_pth_moment(4):
def moment_bound_fn(p):
return approx.pth_moment(var_param, p)
else:
moment_bound_fn = None
_, q_var = approx.mean_and_cov(var_param)
results.update(all_diagnostics(smoothed_log_weights,
samples=samples,
moment_bound_fn=moment_bound_fn,
q_var=q_var))
print('The 2-divergence is estimated to be d2 = {:.2g}'.format(results['d2']))
if results['d2'] > 4.6: # pragma: no cover
print('WARNING: d2 > 4.6 means the approximation is very inaccurate')
elif results['d2'] > 0.1:
print('WARNING: 0.1 < d2 < 4.6 means the approximation is somewhat '
'inaccurate. Use importance sampling to decrease error.')
else:
print('\nAll diagnostics pass.')
return results
def psis_correction(var_param, model, approx, n_samples):
samples, log_weights = samples_and_log_weights(var_param, model, approx, n_samples)
smoothed_log_weights, khat = psislw(log_weights, overwrite_lw=True)
return samples.T, smoothed_log_weights, khat
def samples_and_log_weights(var_param, model, approx, n_samples):
samples = approx.sample(var_param, n_samples)
log_weights = model(samples) - approx.log_density(var_param, samples)
return samples, log_weights