Source code for viabel.approximations

from abc import ABC, abstractmethod

import autograd.numpy as np
import autograd.numpy.random as npr
import autograd.scipy.stats.norm as norm
import autograd.scipy.stats.t as t_dist
from autograd import elementwise_grad
from autograd.scipy.linalg import sqrtm
from paragami import (
    FlattenFunctionInput, NumericArrayPattern, NumericVectorPattern, PatternDict,
    PSDSymmetricMatrixPattern)

from ._distributions import multivariate_t_logpdf

__all__ = [
    'ApproximationFamily',
    'MFGaussian',
    'MFStudentT',
    'MultivariateT',
    'NeuralNet',
    'NVPFlow',
    'LRGaussian'
]


[docs]class ApproximationFamily(ABC): """An abstract class for an variational approximation family. See derived classes for examples. """
[docs] def __init__(self, dim, var_param_dim, supports_entropy, supports_kl): """ Parameters ---------- dim : `int` The dimension of the space the distributions in the approximation family are defined on. var_param_dim : `int` The dimension of the variational parameter. supports_entropy : `bool` Whether the approximation family supports closed-form entropy computation. supports_kl : `bool` Whether the approximation family supports closed-form KL divergence computation. """ self._dim = dim self._var_param_dim = var_param_dim self._supports_entropy = supports_entropy self._supports_kl = supports_kl
[docs] def init_param(self): """A variational parameter to use for initialization. Returns ------- var_param : `numpy.ndarray`, shape (var_param_dim,) """ return np.zeros(self.var_param_dim)
[docs] @abstractmethod def sample(self, var_param, n_samples, seed=None): """Generate samples from the variational distribution Parameters ---------- var_param : `numpy.ndarray`, shape (var_param_dim,) The variational parameter. n_samples : `int` The number of samples to generate. Returns ------- samples : `numpy.ndarray`, shape (n_samples, var_param_dim) """
[docs] def entropy(self, var_param): """Compute entropy of variational distribution. Parameters ---------- var_param : `numpy.ndarray`, shape (var_param_dim,) The variational parameter. Raises ------ NotImplementedError If entropy computation is not supported.""" if self.supports_entropy: return self._entropy(var_param) raise NotImplementedError()
def _entropy(self, var_param): raise NotImplementedError() @property def supports_entropy(self): """Whether the approximation family supports closed-form entropy computation.""" return self._supports_entropy
[docs] def kl(self, var_param0, var_param1): """Compute the Kullback-Leibler (KL) divergence. Parameters ---------- var_param0, var_param1 : `numpy.ndarray`, shape (var_param_dim,) The variational parameters. Raises ------ NotImplementedError If KL divergence computation is not supported. """ if self.supports_kl: return self._kl(var_param0, var_param1) raise NotImplementedError()
def _kl(self, var_param): raise NotImplementedError() @property def supports_kl(self): """Whether the approximation family supports closed-form KL divergence computation.""" return self._supports_kl
[docs] @abstractmethod def log_density(self, var_param, x): """The log density of the variational distribution. Parameters ---------- var_param : `numpy.ndarray`, shape (var_param_dim,) The variational parameter. x : `numpy.ndarray`, shape (dim,) Value at which to evaluate the density."""
[docs] @abstractmethod def mean_and_cov(self, var_param): """The mean and covariance of the variational distribution. Parameters ---------- var_param : `numpy.ndarray`, shape (var_param_dim,) The variational parameter. """
[docs] def pth_moment(self, var_param, p): """The absolute pth moment of the variational distribution. The absolute pth moment is given by :math:`\\mathbb{E}[|X|^p]`. Parameters ---------- var_param : `numpy.ndarray`, shape (var_param_dim,) The variational parameter. p : `int` Raises ------ ValueError If `p` value not supported""" if self.supports_pth_moment(p): return self._pth_moment(var_param, p) raise ValueError('p = {} is not a supported moment'.format(p))
@abstractmethod def _pth_moment(self, var_param, p): """Get pth moment of the approximating distribution"""
[docs] @abstractmethod def supports_pth_moment(self, p): """Whether analytically computing the pth moment is supported"""
@property def dim(self): """Dimension of the space the distribution is defined on""" return self._dim @property def var_param_dim(self): """Dimension of the variational parameter""" return self._var_param_dim
def _get_mu_log_sigma_pattern(dim): ms_pattern = PatternDict(free_default=True) ms_pattern['mu'] = NumericVectorPattern(length=dim) ms_pattern['log_sigma'] = NumericVectorPattern(length=dim) return ms_pattern
[docs]class MFGaussian(ApproximationFamily): """A mean-field Gaussian approximation family."""
[docs] def __init__(self, dim, seed=1): """Create mean field Gaussian approximation family. Parameters ---------- dim : `int` dimension of the underlying parameter space """ self._rs = npr.RandomState(seed) self._pattern = _get_mu_log_sigma_pattern(dim) super().__init__(dim, self._pattern.flat_length(True), True, True)
[docs] def init_param(self): init_param_dict = dict(mu=np.zeros(self.dim), log_sigma=2 * np.ones(self.dim)) return self._pattern.flatten(init_param_dict)
[docs] def sample(self, var_param, n_samples, seed=None): my_rs = self._rs if seed is None else npr.RandomState(seed) param_dict = self._pattern.fold(var_param) return param_dict['mu'] + np.exp(param_dict['log_sigma']) * \ my_rs.randn(n_samples, self.dim)
def _entropy(self, var_param): param_dict = self._pattern.fold(var_param) return 0.5 * self.dim * (1.0 + np.log(2 * np.pi)) + np.sum(param_dict['log_sigma']) def _kl(self, var_param0, var_param1): param_dict0 = self._pattern.fold(var_param0) param_dict1 = self._pattern.fold(var_param1) mean_diff = param_dict0['mu'] - param_dict1['mu'] log_stdev_diff = param_dict0['log_sigma'] - param_dict1['log_sigma'] return .5 * np.sum(np.exp(2 * log_stdev_diff) + mean_diff ** 2 / np.exp(2 * param_dict1['log_sigma']) - 2 * log_stdev_diff - 1)
[docs] def log_density(self, var_param, x): if x.ndim == 1: x = x[np.newaxis, :] param_dict = self._pattern.fold(var_param) return np.sum(norm.logpdf(x, param_dict['mu'], np.exp(param_dict['log_sigma'])), axis=-1)
[docs] def mean_and_cov(self, var_param): param_dict = self._pattern.fold(var_param) return param_dict['mu'], np.diag(np.exp(2 * param_dict['log_sigma']))
def _pth_moment(self, var_param, p): param_dict = self._pattern.fold(var_param) vars = np.exp(2 * param_dict['log_sigma']) if p == 2: return np.sum(vars) else: # p == 4 return 2 * np.sum(vars**2) + np.sum(vars)**2
[docs] def supports_pth_moment(self, p): return p in [2, 4]
[docs]class MFStudentT(ApproximationFamily): """A mean-field Student's t approximation family."""
[docs] def __init__(self, dim, df, seed=1): if df <= 2: raise ValueError('df must be greater than 2') self._df = df self._rs = npr.RandomState(seed) self._pattern = _get_mu_log_sigma_pattern(dim) super().__init__(dim, self._pattern.flat_length(True), True, False)
[docs] def init_param(self): init_param_dict = dict(mu=np.zeros(self.dim), log_sigma=2 * np.ones(self.dim)) return self._pattern.flatten(init_param_dict)
[docs] def sample(self, var_param, n_samples, seed=None): my_rs = self._rs if seed is None else npr.RandomState(seed) param_dict = self._pattern.fold(var_param) return param_dict['mu'] + np.exp(param_dict['log_sigma']) * \ my_rs.standard_t(self.df, size=(n_samples, self.dim))
[docs] def entropy(self, var_param): # ignore terms that depend only on df param_dict = self._pattern.fold(var_param) return np.sum(param_dict['log_sigma'])
[docs] def log_density(self, var_param, x): if x.ndim == 1: x = x[np.newaxis, :] param_dict = self._pattern.fold(var_param) return np.sum(t_dist.logpdf(x, self.df, param_dict['mu'], np.exp( param_dict['log_sigma'])), axis=-1)
[docs] def mean_and_cov(self, var_param): param_dict = self._pattern.fold(var_param) df = self.df cov = df / (df - 2) * np.diag(np.exp(2 * param_dict['log_sigma'])) return param_dict['mu'], cov
def _pth_moment(self, var_param, p): df = self.df if df <= p: raise ValueError('df must be greater than p') param_dict = self._pattern.fold(var_param) scales = np.exp(param_dict['log_sigma']) c = df / (df - 2) if p == 2: return c * np.sum(scales**2) else: # p == 4 return c**2 * (2 * (df - 1) / (df - 4) * np.sum(scales**4) + np.sum(scales**2)**2)
[docs] def supports_pth_moment(self, p): return p in [2, 4] and p < self.df
@property def df(self): """Degrees of freedom.""" return self._df
def _get_mu_sigma_pattern(dim): ms_pattern = PatternDict(free_default=True) ms_pattern['mu'] = NumericVectorPattern(length=dim) ms_pattern['Sigma'] = PSDSymmetricMatrixPattern(size=dim) return ms_pattern
[docs]class MultivariateT(ApproximationFamily): """A full-rank multivariate t approximation family."""
[docs] def __init__(self, dim, df, seed=1): if df <= 2: raise ValueError('df must be greater than 2') self._df = df self._rs = npr.RandomState(seed) self._pattern = _get_mu_sigma_pattern(dim) self._log_density = FlattenFunctionInput( lambda param_dict, x: multivariate_t_logpdf( x, param_dict['mu'], param_dict['Sigma'], df), patterns=self._pattern, free=True, argnums=0) super().__init__(dim, self._pattern.flat_length(True), True, False)
[docs] def init_param(self): init_param_dict = dict(mu=np.zeros(self.dim), Sigma=10 * np.eye(self.dim)) return self._pattern.flatten(init_param_dict)
[docs] def sample(self, var_param, n_samples, seed=None): my_rs = self._rs if seed is None else npr.RandomState(seed) df = self.df s = np.sqrt(my_rs.chisquare(df, n_samples) / df) param_dict = self._pattern.fold(var_param) z = my_rs.randn(n_samples, self.dim) sqrtSigma = sqrtm(param_dict['Sigma']) return param_dict['mu'] + np.dot(z, sqrtSigma) / s[:, np.newaxis]
[docs] def entropy(self, var_param): # ignore terms that depend only on df param_dict = self._pattern.fold(var_param) return .5 * np.log(np.linalg.det(param_dict['Sigma']))
[docs] def log_density(self, var_param, x): return self._log_density(var_param, x)
[docs] def mean_and_cov(self, var_param): param_dict = self._pattern.fold(var_param) df = self.df return param_dict['mu'], df / (df - 2.) * param_dict['Sigma']
def _pth_moment(self, var_param, p): df = self.df if df <= p: raise ValueError('df must be greater than p') param_dict = self._pattern.fold(var_param) sq_scales = np.linalg.eigvalsh(param_dict['Sigma']) c = df / (df - 2) if p == 2: return c * np.sum(sq_scales) else: # p == 4 return c**2 * (2 * (df - 1) / (df - 4) * np.sum(sq_scales**2) + np.sum(sq_scales)**2)
[docs] def supports_pth_moment(self, p): return p in [2, 4] and p < self.df
@property def df(self): """Degrees of freedom.""" return self._df
[docs]class NeuralNet(ApproximationFamily):
[docs] def __init__(self, layers_shapes, nonlinearity=np.tanh, last=np.tanh, mc_samples=10000, seed=1): """ Parameters ---------- layers_shapes : `list of int` The hidden layers dimensions. nonlinearity : `function` Non linear function to apply after each layer except the last layer. last : `function` Non linear function to apply after the last layer. mc_samples : `int` Number of samples to draw internally for computing mean and cov. seed : `int` Internal seed representation. """ self._pattern = PatternDict(free_default=True) self.mc_samples = mc_samples self._layers = len(layers_shapes) self._nonlinearity = nonlinearity self._last = last self._rs = npr.RandomState(seed) self.input_dim = layers_shapes[0][0] for layer_id in range(len(layers_shapes)): self._pattern[str(layer_id)] = NumericArrayPattern(shape=layers_shapes[layer_id]) self._pattern[str(layer_id) + "_b"] = NumericArrayPattern( shape=[layers_shapes[layer_id][1]]) super().__init__(layers_shapes[-1][-1], self._pattern.flat_length(True), False, False)
def forward(self, var_param, x): log_det_J = np.zeros(x.shape[0]) derivative = elementwise_grad(self._nonlinearity) derivative_last = elementwise_grad(self._last) for layer_id in range(self._layers): W = var_param[str(layer_id)] b = var_param[str(layer_id) + "_b"] if layer_id + 1 == self._layers: x = self._last(np.dot(x, W) + b) log_det_J += np.log(np.abs(np.dot(derivative_last(x), W.T).sum(axis=1))) else: x = self._nonlinearity(np.dot(x, W) + b) log_det_J += np.log(np.abs(np.dot(derivative(x), W.T).sum(axis=1))) return x, log_det_J
[docs] def sample(self, var_param, n_samples): z_0 = npr.multivariate_normal(mean=[0] * self.input_dim, cov=np.identity(self.input_dim), size=n_samples) z_k, _ = self.forward(var_param, z_0) return z_k
[docs] def log_density(self, var_param, x): raise NotImplementedError
[docs] def mean_and_cov(self, var_param): samples = self.sample(var_param, self.mc_samples) return np.mean(samples, axis=0), np.cov(samples.T)
def _pth_moment(self, var_param, p): raise NotImplementedError
[docs] def supports_pth_moment(self, p): return False
[docs]class NVPFlow(ApproximationFamily):
[docs] def __init__(self, layers_t, layers_s, mask, prior, prior_param, dim, activation=np.tanh, seed=1, mc_samples=10000): """ Parameters ---------- layers_t : `list of int` The hidden layers dimensions for the translation operator. layers_s : `list of int` The hidden layers dimensions for the scaling operator. mask : `mask int` Mask to apply to the entry of each operator. prior : `ApproximationFamily` Prior for the latent space Z. prior_param : `numpy array` Parameter vector for the prior, must follow the same format as any variational family. dim : `int` Input dimension. seed : `int` Random seed for reproducibility. mc_samples : `int` Number of samples to draw internally for computing mean and cov. """ assert len(layers_t) == len(layers_s) self.prior = prior self.prior_param = prior_param self.mc_samples = mc_samples self._dim = dim self._rs = npr.RandomState(seed) self.mask = mask self._pattern = PatternDict(free_default=True) self.t = [NeuralNet(layers_t, nonlinearity=activation, last=lambda x: x) for _ in range(len(mask))] self.s = [NeuralNet(layers_s, nonlinearity=activation, last=np.tanh) for _ in range(len(mask))] for layer_id in range(len(mask)): self._pattern[str(layer_id) + "t"] = self.t[layer_id]._pattern self._pattern[str(layer_id) + "s"] = self.s[layer_id]._pattern super().__init__(dim, self._pattern.flat_length(True), False, False)
[docs] def g(self, var_param, z): """Inverse NVP flow. Parameters ---------- var_param : `numpy array` Flat array of variational parameters. z : `numpy array` Latent space sample. """ x = z param_dict = self._pattern.fold(var_param) for i in range(len(self.t)): x_ = x * self.mask[i] s = self.s[i].forward(param_dict[str(i) + "s"], x_)[0] * (1 - self.mask[i]) t = self.t[i].forward(param_dict[str(i) + "t"], x_)[0] * (1 - self.mask[i]) x = x_ + (1 - self.mask[i]) * (x * np.exp(s) + t) return x
[docs] def f(self, var_param, x): """Forward NVP flow. Parameters ---------- var_param : `numpy array` Flat array of variational parameters. x : `numpy array` Original space data. """ param_dict = self._pattern.fold(var_param) log_det_J, z = np.zeros(x.shape[0]), x for i in reversed(range(len(self.t))): z_ = self.mask[i] * z s = self.s[i].forward(param_dict[str(i) + "s"], z_)[0] * (1 - self.mask[i]) t = self.t[i].forward(param_dict[str(i) + "t"], z_)[0] * (1 - self.mask[i]) z = (1 - self.mask[i]) * (z - t) * np.exp(-s) + z_ log_det_J -= s.sum(axis=1) return z, log_det_J
[docs] def log_density(self, var_param, x): z, logp = self.f(var_param, x) return self.prior.log_density(self.prior_param, z) + logp
[docs] def sample(self, var_param, n_samples, seed=None): z_0 = self.prior.sample(self.prior_param, int(n_samples), seed=seed) z_k = self.g(var_param, z_0) return z_k
[docs] def mean_and_cov(self, var_param): samples = self.sample(var_param, self.mc_samples) return np.mean(samples, axis=0), np.cov(samples.T)
def _pth_moment(self, var_param, p): raise NotImplementedError
[docs] def supports_pth_moment(self, p): return False
def _get_low_rank_mu_sigma_pattern(dim, k): ms_pattern = PatternDict(free_default=True) ms_pattern['mu'] = NumericVectorPattern(length=dim) ms_pattern['log_sigma'] = NumericVectorPattern(length=dim) ms_pattern['low_rank'] = NumericArrayPattern(shape=(dim, k)) return ms_pattern def _get_log_determinant(D, B): """Compute the determinant of the matrix B @ B.T + np.diag(D) using the matrix determinant lemma. Parameters ---------- D : `numpy vector` diagnal component of covariance matrix. B : `numpy array` low rank component of covariance matrix. """ log_det_D = 2*np.sum(D) _,log_det_IpDBBT = np.linalg.slogdet(np.eye(len(D)) + B @ B.T/np.exp(2*D[:,np.newaxis])) log_det_M = log_det_D + log_det_IpDBBT return log_det_M def _get_trace(D0, B0, D1, B1): """Compute the trace of the product of the inverse of B1 @ B1.T + np.diag(D1) and B0 @ B0.T + np.diag(D0). Parameters ---------- D0 : `numpy vector` Diagonal elements of sigma0. B0 : `numpy vector` Low-rank component of sigma0. D1 : `numpy array` Diagonal elements of sigma1. B1 : `numpy array` Low-rank component of sigma1. """ I_B1D1B1 = np.eye(B1.shape[1]) + B1.T / D1 @ B1 invD1_B1 = B1 / D1[:, np.newaxis] invD1_B1_I_B1D1B1_inv = np.linalg.solve(I_B1D1B1.T, invD1_B1.T).T product = invD1_B1_I_B1D1B1_inv @ (B1.T / D1) # Compute Tr(D0 * B1 * (I + B1^T * D1^-1 * B1)^-1 * B1^T * D1^-1) trace_product = np.trace(product * D0) # Compute Tr(D0 * D1^-1) trace_D0_invD1 = np.sum(D0 / D1) # Compute Tr(np.diag(D1)^(-1) * B0 @ B0.T) trace_invD1_B0B0T = np.trace(B0 @ B0.T / D1) # Compute Tr(np.diag(D1)^(-1) * B1 * (I + B1.T * D1^-1 * B1)^-1 * B1^T * D1^-1 * B0 @ B0.T) trace_extra_term = np.trace(product @ B0 @ B0.T) # Return Tr(sigma0 * sigma1^-1) = Tr(D0 * D1^-1) + Tr(np.diag(D1)^(-1) * B0 @ B0.T) - Tr(D0 * np.diag(D1)^(-1) * B1 * (I + B1^T * D1^-1 * B1)^-1 * B1^T * D1^-1) - Tr(np.diag(D1)^(-1) * B1 * (I + B1.T * D1^-1 * B1)^-1 * B1^T * D1^-1 * B0 @ B0.T) return trace_D0_invD1 + trace_invD1_B0B0T - trace_product - trace_extra_term class LRGaussian(ApproximationFamily): """A low rank Gaussian approximation family.""" def __init__(self, dim, seed=1, k=0): """Create multivariate Gaussian approximation family. Parameters ---------- dim : `int` dimension of the underlying parameter space k : 'int' number of low rank """ self._rs = npr.RandomState(seed) self._pattern = _get_low_rank_mu_sigma_pattern(dim, k) self._k = k super().__init__(dim, self._pattern.flat_length(True), True, True) def init_param(self): init_param_dict = dict(mu=np.zeros(self.dim), log_sigma=np.ones(self.dim), low_rank=self._rs.randn(self.dim,self._k)) return self._pattern.flatten(init_param_dict) def sample(self, var_param, n_samples, seed=None): my_rs = self._rs if seed is None else npr.RandomState(seed) param_dict = self._pattern.fold(var_param) z = my_rs.randn(n_samples,self._k) epsilon = my_rs.randn(n_samples, self.dim) D_exp = np.exp(param_dict['log_sigma']) B = param_dict['low_rank'] return param_dict['mu'] + np.dot(z,B.T) + D_exp * epsilon def _entropy(self, var_param): param_dict = self._pattern.fold(var_param) B = param_dict['low_rank'] D = param_dict['log_sigma'] sigma_log_det = _get_log_determinant(D, B) return 0.5 * self.dim * (np.log(2 * np.pi ) + 1) + 0.5 * sigma_log_det def _kl(self, var_param0, var_param1): param_dict0 = self._pattern.fold(var_param0) param_dict1 = self._pattern.fold(var_param1) mean_diff = param_dict0['mu'] - param_dict1['mu'] B0 = param_dict0['low_rank'] D0 = param_dict0['log_sigma'] D0_exp = np.exp(2 * param_dict0['log_sigma']) sigma0_log_det = _get_log_determinant(D0, B0) B1 = param_dict1['low_rank'] D1 = param_dict1['log_sigma'] D1_exp = np.exp(2 * param_dict1['log_sigma']) D1_inv = np.diag(1/D1_exp) sigma1_log_det = _get_log_determinant(D1, B1) #By the Woodbury formula D1_invB = np.dot(D1_inv, B1) I_BDB = np.eye(self._k) + np.dot(B1.T, D1_invB) I_BDB_inv = np.linalg.solve(I_BDB, np.identity(I_BDB.shape[0])) Sigma_inv = D1_inv - D1_invB @ I_BDB_inv @ D1_invB.T #KL divergence sigma_log_diff = sigma1_log_det - sigma0_log_det mean_sigma = mean_diff.T @ Sigma_inv @ mean_diff sigma_trace = _get_trace(D0_exp, B0, D1_exp, B1) return .5 * (sigma_log_diff - self.dim + mean_sigma + sigma_trace) def log_density(self,var_param, x): if x.ndim == 1: x = x[np.newaxis, :] param_dict = self._pattern.fold(var_param) mean = param_dict['mu'] B = param_dict['low_rank'] D = param_dict['log_sigma'] D_exp = np.exp(2*D) D_inv = np.diag(1/D_exp) sigma_log_det = _get_log_determinant(D, B) # By the Woodbury formula D_invB = D_inv @ B I_BDB = np.eye(self._k) + np.dot(B.T, D_invB) I_BDB_inv = np.linalg.solve(I_BDB, np.identity(I_BDB.shape[0])) sigma_inv = D_inv - D_invB @ I_BDB_inv @ D_invB.T # Compute the log density of the multivariate Gaussian distribution for each row of X diff = x - mean log_p = -0.5 * (self.dim * np.log(2 * np.pi) + sigma_log_det + np.sum(diff @ sigma_inv * diff, axis=1)) return log_p def mean_and_cov(self, var_param): param_dict = self._pattern.fold(var_param) B = param_dict['low_rank'] D_exp = np.exp(2 * param_dict['log_sigma']) return param_dict['mu'], np.dot(B, B.T) + np.diag(D_exp) def _pth_moment(self, var_param, p): param_dict = self._pattern.fold(var_param) D_exp = np.exp(2*param_dict['log_sigma']) B = param_dict['low_rank'] covariance = B @ B.T + np.diag(D_exp) #check later eigvals = np.linalg.eigvalsh(covariance) if p == 2: return np.sum(eigvals) else: # p == 4 return 2 * np.sum(eigvals ** 2) + np.sum(eigvals) ** 2 def supports_pth_moment(self, p): return p in [2, 4]