# Source code for viabel.objectives

from abc import ABC, abstractmethod

import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import value_and_grad, vector_jacobian_product, make_hvp, elementwise_grad, grad, hessian
from autograd.core import getval

# Names exported by a star-import of this module.
__all__ = [
    'VariationalObjective',
    'StochasticVariationalObjective',
    'ExclusiveKL',
    'DISInclusiveKL',
    'AlphaDivergence'
]


class VariationalObjective(ABC):
    """Base class for a variational objective that is to be minimized."""

    def __init__(self, approx, model):
        """
        Parameters
        ----------
        approx : `ApproximationFamily` object
        model : `Model` object
        """
        self._approx = approx
        self._model = model
        # Subclasses install the callable in `_update_objective_and_grad`.
        self._objective_and_grad = None
        self._update_objective_and_grad()

    def __call__(self, var_param):
        """Evaluate the objective and its gradient.

        The result may be an (un)biased estimate of both quantities.

        Parameters
        ----------
        var_param : `numpy.ndarray`, shape (var_param_dim,)
            The variational parameter.

        Raises
        ------
        RuntimeError
            If no objective-and-gradient function has been installed.
        """
        objective_and_grad = self._objective_and_grad
        if objective_and_grad is not None:
            return objective_and_grad(var_param)
        raise RuntimeError("no objective and gradient available")

    @abstractmethod
    def _update_objective_and_grad(self):
        """Rebuild the objective and gradient function.

        Must be invoked whenever something the objective depends on
        (e.g., `approx` or `model`) is replaced.
        """

    def _hessian_vector_product(self, var_param, x):
        """Hessian-vector product of the objective at `var_param` with `x`.

        The base implementation does nothing and returns ``None``.
        """
        return None

    def update(self, var_param, direction):
        """Take one optimization step: subtract `direction` from `var_param`."""
        return var_param - direction

    @property
    def approx(self):
        """The approximation family."""
        return self._approx

    @approx.setter
    def approx(self, value):
        # Swapping the family invalidates the cached objective function.
        self._approx = value
        self._update_objective_and_grad()

    @property
    def model(self):
        """The model."""
        return self._model

    @model.setter
    def model(self, value):
        # Swapping the model invalidates the cached objective function.
        self._model = value
        self._update_objective_and_grad()
class StochasticVariationalObjective(VariationalObjective):
    """Variational objective whose value is approximated by Monte Carlo."""

    def __init__(self, approx, model, num_mc_samples):
        """
        Parameters
        ----------
        approx : `ApproximationFamily` object
        model : `Model` object
        num_mc_samples : `int`
            Number of Monte Carlo samples to use to approximate the objective.
        """
        # Stored before delegating to the parent initializer so the value is
        # already available to whatever the parent sets up.
        self._num_mc_samples = num_mc_samples
        super().__init__(approx, model)

    @property
    def num_mc_samples(self):
        """Number of Monte Carlo samples to use to approximate the objective."""
        return self._num_mc_samples

    @num_mc_samples.setter
    def num_mc_samples(self, value):
        # A new sample count requires the objective function to be rebuilt.
        self._num_mc_samples = value
        self._update_objective_and_grad()
class ExclusiveKL(StochasticVariationalObjective):
    """Exclusive Kullback-Leibler divergence.

    Equivalent to using the canonical evidence lower bound (ELBO) with
    reparameterized gradient estimator and control variate.

    This implementation of reparameterization and control variate is based on:
    "Reducing Reparameterization Gradient Variance" by Andrew C. Miller,
    Nicholas J. Foti, Alexander D'Amour, and Ryan P. Adams.
    Code based on the implementation by Andrew C. Miller:
    https://github.com/andymiller/ReducedVarianceReparamGradients
    """

    def __init__(self, approx, model, num_mc_samples, use_path_deriv=False,
                 hessian_approx_method=None):
        """
        Parameters
        ----------
        approx : `ApproximationFamily` object
        model : `Model` object
        num_mc_samples : `int`
            Number of Monte Carlo samples to use to approximate the objective.
        use_path_deriv : `bool`
            Use path derivative (for "sticking the landing") gradient estimator
        hessian_approx_method : `str` or `None`
            Selects how the control variate approximates the hessian.
            ``None`` disables the control variate and differentiates the
            Monte Carlo objective directly. Otherwise one of:

            'full' : use the full hessian matrix of the model evaluated at
                the approximation mean
            'mean_only' : use control variate only for mean estimator to
                avoid calculation of full hessian
            'loo_diag_approx' : using "leave one out" method with hessian
                vector product value at other samples to estimate the
                diagonal values of hessian
            'loo_direct_approx' : the same method as 'loo_diag_approx' but
                use the scaled approximation to the gradient of scale to do
                the "loo" estimation

        Raises
        ------
        ValueError
            If `hessian_approx_method` is not one of the values above.
        """
        self._use_path_deriv = use_path_deriv
        if hessian_approx_method not in [None, 'full', 'mean_only',
                                         'loo_diag_approx', 'loo_direct_approx']:
            raise ValueError("Name of approximation must be one of 'full', 'mean_only', 'loo_diag_approx', 'loo_direct_approx' or None object.")
        # Must be set before the parent initializer runs, because
        # `_update_objective_and_grad` reads it.
        self.hessian_approx_method = hessian_approx_method
        super().__init__(approx, model, num_mc_samples)

    def _update_objective_and_grad(self):
        """Install the objective/gradient callable for the current settings.

        With ``hessian_approx_method is None`` the negative ELBO estimate is
        differentiated with autograd. Otherwise a hand-written
        reparameterization gradient with a control variate is installed.
        """
        approx = self.approx

        def lower_bound_estimate(var_param, samples):
            """Monte Carlo ELBO estimate shared by both code paths."""
            if self._use_path_deriv:
                # Stop gradients through the density's parameters so only
                # the dependence via the samples is differentiated.
                var_param_stopped = getval(var_param)
                return np.mean(
                    self.model(samples) - approx.log_density(var_param_stopped, samples))
            if approx.supports_entropy:
                return np.mean(self.model(samples)) + approx.entropy(var_param)
            # log_density takes the variational parameter first, as in the
            # path-derivative branch above.
            return np.mean(self.model(samples) - approx.log_density(var_param, samples))

        if self.hessian_approx_method is None:
            def variational_objective(var_param):
                samples = approx.sample(var_param, self.num_mc_samples)
                return -lower_bound_estimate(var_param, samples)

            self._hvp = make_hvp(variational_objective)
            self._objective_and_grad = value_and_grad(variational_objective)
            return

        # No autograd-differentiable objective exists on this path, so make
        # sure a stale hessian-vector product from a previous setup is not used.
        self._hvp = None

        def RGE(var_param):
            z_samples = approx.sample(var_param, self.num_mc_samples)
            m_mean, cov = approx.mean_and_cov(var_param)
            s_scale = np.sqrt(np.diag(cov))
            epsilon_sample = (z_samples - m_mean) / s_scale
            lower_bound = lower_bound_estimate(var_param, z_samples)

            # self.model takes in one single parameter to calculate grad and hessian
            def f_model(x):
                x = np.atleast_2d(x)
                return self._model(x)

            grad_f = elementwise_grad(self.model)
            dLdm = grad_f(z_samples)
            # gradient with respect to the log-std
            dLdlns = dLdm * epsilon_sample * s_scale + 1
            # var_param MC gradient, one row per sample: [mean part, log-scale part]
            g_hat_rprm_grad = np.column_stack([dLdm, dLdlns])

            method = self.hessian_approx_method
            num_draws = np.shape(z_samples)[0]
            if method in ('loo_diag_approx', 'loo_direct_approx') and num_draws < 2:
                # The leave-one-out averages divide by (num_draws - 1).
                raise ValueError("loo approximations require more than 1 sample")

            # Reparameterization with a control variate to reduce the
            # variance of the gradient estimate.
            if method == "full":
                hessian_f = hessian(f_model)
                # Miller's implementation
                gmu = grad_f(m_mean)
                H = hessian_f(m_mean).squeeze()
                Hdiag = np.diag(H)
                # construct normal approx samples of data term
                dLdz = gmu + np.dot(H, (s_scale * epsilon_sample).T).T
                dLds = dLdz * epsilon_sample * s_scale + 1.
                elbo_gsamps_tilde = np.column_stack([dLdz, dLds])
                # characterize the mean of the dLds component (and z comp)
                dLds_mu = (Hdiag * s_scale + 1 / s_scale) * s_scale
                gsamps_tilde_mean = np.concatenate([gmu, dLds_mu])
                # subtract mean to compute control variate
                elbo_gsamps_cv = g_hat_rprm_grad - (elbo_gsamps_tilde - gsamps_tilde_mean)
                g_hat_rv = np.mean(elbo_gsamps_cv, axis=0)
            elif method == "mean_only":
                grad_f_single = grad(f_model)
                # linear approximation of gradient: mean
                scaled_samples = np.multiply(s_scale, epsilon_sample)
                a = grad_f(m_mean * np.ones_like(z_samples))
                hvp = make_hvp(f_model)(m_mean)
                b = np.array([hvp[0](s) for s in scaled_samples])
                g_tilde_mean_approx = a + b
                # linear approximation of gradient: log-scale
                g_tilde_scale_approx_ln = np.zeros_like(g_tilde_mean_approx)
                # Expectation of linear approximation of gradient: mean
                E_g_tilde_mean = grad_f_single(m_mean)
                # Expectation of linear approximation of gradient: log-scale
                E_g_tilde_scale_ln = np.zeros_like(E_g_tilde_mean)
                g_tilde = np.column_stack([g_tilde_mean_approx, g_tilde_scale_approx_ln])
                E_g_tilde = np.concatenate([E_g_tilde_mean, E_g_tilde_scale_ln])
                E_g_tilde = np.multiply(E_g_tilde, np.ones_like(g_tilde))
                g_hat_rv = np.mean(g_hat_rprm_grad - (g_tilde - E_g_tilde), axis=0)
            elif method == "loo_diag_approx":
                # use other samples to estimate a per-sample diagonal expectation
                # compute hessian vector products and save them for both parts
                hvp_lam = make_hvp(f_model)(m_mean)[0]
                hvps = np.array([hvp_lam(s_scale * e) for e in epsilon_sample])
                gmu = grad_f(m_mean * np.ones_like(z_samples))
                # construct normal approx samples of data term
                dLdz = gmu + hvps
                dLds = dLdz * (epsilon_sample * s_scale) + 1
                # compute Leave One Out approximate diagonal (per-sample mean of dLds)
                Hdiag_sum = np.sum(epsilon_sample * hvps, axis=0)
                Hdiag_s = (Hdiag_sum[None, :] - epsilon_sample * hvps) / float(num_draws - 1)
                dLds_mu = (Hdiag_s + 1 / s_scale[None, :]) * s_scale
                # compute gsamps_cv - mean(gsamps_cv), and finally the var reduced
                D = int(0.5 * np.shape(g_hat_rprm_grad)[1])
                g_hat_rv = g_hat_rprm_grad.copy()
                g_hat_rv[:, :D] -= hvps
                g_hat_rv[:, D:] -= (dLds - dLds_mu)
                g_hat_rv = np.mean(g_hat_rv, axis=0)
            elif method == "loo_direct_approx":
                hvp_lam = make_hvp(f_model)(m_mean)[0]
                gmu = grad_f(m_mean * np.ones_like(z_samples))
                hvps = np.array([hvp_lam(s_scale * e) for e in epsilon_sample])
                # construct normal approx samples of data term
                dLdz = gmu + hvps
                dLds = (dLdz * epsilon_sample + 1 / s_scale[None, :]) * s_scale
                # compute Leave One Out approximate diagonal (per-sample mean of dLds)
                dLds_sum = np.sum(dLds, axis=0)
                dLds_mu = (dLds_sum[None, :] - dLds) / float(num_draws - 1)
                # compute gsamps_cv - mean(gsamps_cv), and finally the var reduced
                elbo_gsamps_tilde_centered = np.column_stack([hvps, dLds - dLds_mu])
                g_hat_rv = np.mean(g_hat_rprm_grad - elbo_gsamps_tilde_centered, axis=0)
            else:
                raise RuntimeError("Invalid hessian approximation method!")
            # Negated because the objective being minimized is the negative ELBO.
            return -lower_bound, -g_hat_rv

        self._objective_and_grad = RGE

    def _hessian_vector_product(self, var_param, x):
        """Hessian-vector product of the objective at `var_param` with `x`.

        Only available when `hessian_approx_method` was ``None`` at the time
        the objective was last built.

        Raises
        ------
        RuntimeError
            If no hessian-vector product function is available.
        """
        make_product = getattr(self, '_hvp', None)
        if make_product is None:
            raise RuntimeError(
                "hessian vector product is only available when "
                "hessian_approx_method is None")
        hvp_fun = make_product(var_param)[0]
        return hvp_fun(x)
class DISInclusiveKL(StochasticVariationalObjective):
    """Inclusive Kullback-Leibler divergence using Distilled Importance Sampling."""

    def __init__(self, approx, model, num_mc_samples, ess_target,
                 temper_prior, temper_prior_params, use_resampling=True,
                 num_resampling_batches=1, w_clip_threshold=10):
        """
        Parameters
        ----------
        approx : `ApproximationFamily` object
        model : `Model` object
        num_mc_samples : `int`
            Number of Monte Carlo samples to use to approximate the KL
            divergence. (N in the paper)
        ess_target : `int`
            The ess target to adjust epsilon (M in the paper). It is also
            the number of samples in resampling.
        temper_prior : `Model` object
            A prior distribution to temper the model. Typically multivariate
            normal.
        temper_prior_params : `numpy.ndarray` object
            Parameters for the temper prior. Typically mean 0 and variance 1.
        use_resampling : `bool`
            Whether to use resampling.
        num_resampling_batches : `int`
            Number of resampling batches. The resampling batch size is
            `max(1, ess_target // num_resampling_batches)`.
        w_clip_threshold : `float`
            Weights larger than `w_clip_threshold` times the sum of all
            weights are clipped.
        """
        self._ess_target = ess_target
        self._w_clip_threshold = w_clip_threshold
        self._max_bisection_its = 50
        self._max_eps = self._eps = 1
        self._use_resampling = use_resampling
        self._num_resampling_batches = num_resampling_batches
        self._resampling_batch_size = max(1, self._ess_target // num_resampling_batches)
        self._objective_step = 0

        def tempered_model_log_pdf(eps, samples, log_p_unnormalized):
            # Mix the tempering prior and the target on the log scale.
            return (eps * temper_prior.log_density(temper_prior_params, samples)
                    + (1 - eps) * log_p_unnormalized)

        self._tempered_model_log_pdf = tempered_model_log_pdf
        super().__init__(approx, model, num_mc_samples)

    def _get_weights(self, eps, samples, log_p_unnormalized, log_q):
        """Calculate importance sampling weights for the tempered target.

        The weights are ``exp(log tempered density - log_q)`` and are not
        rescaled to sum to one.

        Raises
        ------
        ValueError
            If every log weight is ``-inf``.
        """
        logw = self._tempered_model_log_pdf(eps, samples, log_p_unnormalized) - log_q
        max_logw = np.max(logw)
        if max_logw == -np.inf:
            raise ValueError('All weights zero! '
                             + 'Suggests overflow in importance density.')
        w = np.exp(logw)
        return w

    def _get_ess(self, w):
        """Calculate the effective sample size of importance sampling weights."""
        ess = (np.sum(w) ** 2.0) / np.sum(w ** 2.0)
        return ess

    def _get_eps_and_weights(self, eps_guess, samples, log_p_unnormalized, log_q):
        """Find new epsilon value

        Uses bisection on ``[0, eps_guess]`` to find an epsilon giving the
        required ESS. Returns the new epsilon value together with the ESS
        and importance sampling weights evaluated at the final bisection
        midpoint.
        """
        lower = 0.
        upper = eps_guess
        eps_guess = (lower + upper) / 2.
        for _ in range(self._max_bisection_its):
            w = self._get_weights(eps_guess, samples, log_p_unnormalized, log_q)
            ess = self._get_ess(w)
            if ess > self._ess_target:
                upper = eps_guess
            else:
                lower = eps_guess
            eps_guess = (lower + upper) / 2.

        w = self._get_weights(eps_guess, samples, log_p_unnormalized, log_q)
        ess = self._get_ess(w)

        # Consider returning extreme epsilon values if they are still endpoints
        if lower == 0.:
            eps_guess = 0.
        if upper == self._max_eps:
            eps_guess = self._max_eps

        return eps_guess, ess, w

    def _clip_weights(self, w):
        """Clip weights to `self._w_clip_threshold` times the total weight.

        Weights at or above the threshold are replaced, in place, by a common
        value chosen so that each of them equals `self._w_clip_threshold`
        times the new sum of weights. The procedure repeats until no weight
        exceeds the threshold or nothing is left unclipped.
        """
        S = np.sum(w)
        if not np.any(w > S * self._w_clip_threshold):
            return w

        # nb clip those equal to max_weight so we don't push them over it!
        to_clip = (w >= S * self._w_clip_threshold)
        n_to_clip = np.sum(to_clip)
        to_not_clip = np.logical_not(to_clip)
        sum_unclipped = np.sum(w[to_not_clip])
        if sum_unclipped == 0:
            # Impossible to clip further!
            return w
        # Solve c = threshold * (sum_unclipped + n_to_clip * c) for c.
        w[to_clip] = (self._w_clip_threshold * sum_unclipped
                      / (1. - self._w_clip_threshold * n_to_clip))
        return self._clip_weights(w)

    def _update_objective_and_grad(self):
        """Install the autograd value-and-gradient of the DIS objective."""
        approx = self.approx

        def variational_objective(var_param):
            # Draw a fresh weighted sample on every call without resampling,
            # otherwise once every `num_resampling_batches` calls.
            if (not self._use_resampling
                    or self._objective_step % self._num_resampling_batches == 0):
                self._state_samples = getval(approx.sample(var_param, self.num_mc_samples))
                self._state_log_q = approx.log_density(var_param, self._state_samples)
                self._state_log_p_unnormalized = self.model(self._state_samples)

                self._eps, ess, w = self._get_eps_and_weights(
                    self._eps, self._state_samples,
                    self._state_log_p_unnormalized, self._state_log_q)
                self._state_w_clipped = self._clip_weights(w)
                self._state_w_sum = np.sum(self._state_w_clipped)
                self._state_w_normalized = self._state_w_clipped / self._state_w_sum

            self._objective_step += 1

            if not self._use_resampling:
                # Weights are treated as constants; only log_q is differentiated.
                return -np.inner(getval(self._state_w_clipped),
                                 self._state_log_q) / self.num_mc_samples

            indices = np.random.choice(self.num_mc_samples,
                                       size=self._resampling_batch_size,
                                       p=getval(self._state_w_normalized))
            samples_resampled = self._state_samples[indices]
            obj = np.mean(-approx.log_density(var_param, getval(samples_resampled)))
            return obj * getval(self._state_w_sum) / self.num_mc_samples

        self._objective_and_grad = value_and_grad(variational_objective)
class AlphaDivergence(StochasticVariationalObjective):
    """Log of the alpha-divergence."""

    def __init__(self, approx, model, num_mc_samples, alpha):
        """
        Parameters
        ----------
        approx : `ApproximationFamily` object
        model : `Model` object
        num_mc_samples : `int`
            Number of Monte Carlo samples to use to approximate the objective.
        alpha : `float`
        """
        # Stored before the parent initializer runs, since building the
        # objective reads `self.alpha`.
        self._alpha = alpha
        super().__init__(approx, model, num_mc_samples)

    @property
    def alpha(self):
        """Alpha parameter of the divergence."""
        return self._alpha

    def _update_objective_and_grad(self):
        """Install a stochastic estimate of the objective and its gradient."""

        def compute_log_weights(var_param, seed):
            # Log importance weights: model log density minus approximation
            # log density, at samples drawn with the given seed.
            draws = self.approx.sample(var_param, self.num_mc_samples, seed)
            return self.model(draws) - self.approx.log_density(var_param, draws)

        log_weights_vjp = vector_jacobian_product(compute_log_weights)
        # Captured once here; the installed function keeps using this value.
        alpha = self.alpha

        def objective_grad_and_log_norm(var_param):
            # The value and the vector-Jacobian product must see the same
            # samples, so both are evaluated with one shared seed.
            shared_seed = npr.randint(2 ** 32)
            log_w = compute_log_weights(var_param, shared_seed)
            # Shift by the maximum before exponentiating.
            shift = np.max(log_w)
            powered = np.exp(log_w - shift) ** alpha
            value = np.log(np.mean(powered)) / alpha + shift
            gradient = alpha * log_weights_vjp(var_param, shared_seed, powered) / powered.size
            return (value, gradient)

        self._objective_and_grad = objective_grad_and_log_norm