* BUGFIX: See code associated with `scale_identity_multiplier`

* BUGFIX: See code associated with `weight` inside `sample_n`
* Use parameterization `temperature` rather than `mix_scale` (see the usage sketch after this list)
* Simplify documentation and link to the arXiv paper for details
* Document that we allow `temperature` to have any shape broadcastable with `mix_loc`. This is a moot point to some degree since we require K = 2 for now.
* Add some tests
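A minimal usage sketch of the renamed argument, mirroring the docstring example updated in this commit (whether the LinearOperator classes live under `tf.linalg` or `tf.contrib.linalg` depends on the TF version; the values here are illustrative):

```python
import numpy as np
import tensorflow as tf

tfd = tf.contrib.distributions

dims = 5
vdm = tfd.VectorDiffeomixture(
    mix_loc=[[0.], [1.]],   # two batch members; K = 2 components each
    temperature=[1.],       # formerly `mix_scale`; smaller => more mixture-like
    distribution=tfd.Normal(loc=0., scale=1.),
    loc=[
        None,                     # component 0: equivalent to zeros(dims)
        np.float32([2.] * dims),  # component 1: all twos
    ],
    scale=[
        tf.linalg.LinearOperatorScaledIdentity(
            num_rows=dims,
            multiplier=np.float32(1.1),
            is_positive_definite=True),
    ] * 2,  # same scale for both components, as in the new tests
    validate_args=True)
```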

PiperOrigin-RevId: 183098275
Ian Langmore 2018-01-24 09:49:47 -08:00 committed by TensorFlower Gardener
parent ffdae0a357
commit 7b62a71e2d
2 changed files with 256 additions and 188 deletions
tensorflow/contrib/distributions/python


@@ -27,6 +27,8 @@ from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib
 from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib
 from tensorflow.python.platform import test
+
+rng = np.random.RandomState(0)
 
 class VectorDiffeomixtureTest(
     test_util.VectorDistributionTestHelpers, test.TestCase):
@@ -37,7 +39,7 @@ class VectorDiffeomixtureTest(
       dims = 4
       vdm = vdm_lib.VectorDiffeomixture(
           mix_loc=[[0.], [1.]],
-          mix_scale=[1.],
+          temperature=[1.],
           distribution=normal_lib.Normal(0., 1.),
           loc=[
               None,
@@ -66,7 +68,7 @@ class VectorDiffeomixtureTest(
       dims = 4
       vdm = vdm_lib.VectorDiffeomixture(
           mix_loc=[[0.], [1.]],
-          mix_scale=[1.],
+          temperature=[1.],
           distribution=normal_lib.Normal(1., 1.5),
           loc=[
               None,
@@ -95,7 +97,7 @@ class VectorDiffeomixtureTest(
       dims = 4
       vdm = vdm_lib.VectorDiffeomixture(
           mix_loc=[[0.], [1.]],
-          mix_scale=[1.],
+          temperature=[1.],
           distribution=normal_lib.Normal(0., 1.),
           loc=[
               None,
@@ -122,12 +124,39 @@ class VectorDiffeomixtureTest(
       self.run_test_sample_consistent_log_prob(
           sess.run, vdm, radius=4., center=2., rtol=0.01)
 
+  def testSampleProbConsistentBroadcastMixTwoBatchDims(self):
+    dims = 4
+    loc_1 = rng.randn(2, 3, dims).astype(np.float32)
+    with self.test_session() as sess:
+      vdm = vdm_lib.VectorDiffeomixture(
+          mix_loc=(rng.rand(2, 3, 1) - 0.5).astype(np.float32),
+          temperature=[1.],
+          distribution=normal_lib.Normal(0., 1.),
+          loc=[
+              None,
+              loc_1,
+          ],
+          scale=[
+              linop_identity_lib.LinearOperatorScaledIdentity(
+                  num_rows=dims,
+                  multiplier=[np.float32(1.1)],
+                  is_positive_definite=True),
+          ] * 2,
+          validate_args=True)
+      # Ball centered at component0's mean.
+      self.run_test_sample_consistent_log_prob(
+          sess.run, vdm, radius=2., center=0., rtol=0.01)
+      # Larger ball centered at component1's mean.
+      self.run_test_sample_consistent_log_prob(
+          sess.run, vdm, radius=3., center=loc_1, rtol=0.02)
+
   def testMeanCovarianceNoBatch(self):
     with self.test_session() as sess:
       dims = 3
       vdm = vdm_lib.VectorDiffeomixture(
           mix_loc=[[0.], [4.]],
-          mix_scale=[10.],
+          temperature=[1 / 10.],
           distribution=normal_lib.Normal(0., 1.),
           loc=[
               np.float32([-2.]),
@@ -147,12 +176,94 @@ class VectorDiffeomixtureTest(
       self.run_test_sample_consistent_mean_covariance(
           sess.run, vdm, rtol=0.02, cov_rtol=0.08)
 
+  def testTemperatureControlsHowMuchThisLooksLikeDiscreteMixture(self):
+    # As temperature decreases, this should approach a mixture of normals, with
+    # components at -2, 2.
+    with self.test_session() as sess:
+      dims = 1
+      vdm = vdm_lib.VectorDiffeomixture(
+          mix_loc=[0.],
+          temperature=[[2.], [1.], [0.2]],
+          distribution=normal_lib.Normal(0., 1.),
+          loc=[
+              np.float32([-2.]),
+              np.float32([2.]),
+          ],
+          scale=[
+              linop_identity_lib.LinearOperatorScaledIdentity(
+                  num_rows=dims,
+                  multiplier=np.float32(0.5),
+                  is_positive_definite=True),
+          ] * 2,  # Use the same scale for each component.
+          quadrature_size=8,
+          validate_args=True)
+      samps = vdm.sample(10000)
+      self.assertAllEqual((10000, 3, 1), samps.shape)
+      samps_ = sess.run(samps).reshape(10000, 3)  # Make scalar event shape.
+      # One characteristic of a discrete mixture (as opposed to a "smear") is
+      # that more weight is put near the component centers at -2, 2, and thus
+      # less weight is put near the origin.
+      prob_of_being_near_origin = (np.abs(samps_) < 1).mean(axis=0)
+      self.assertGreater(
+          prob_of_being_near_origin[0], prob_of_being_near_origin[1])
+      self.assertGreater(
+          prob_of_being_near_origin[1], prob_of_being_near_origin[2])
+      # Run this test as well, just because we can.
+      self.run_test_sample_consistent_mean_covariance(
+          sess.run, vdm, rtol=0.02, cov_rtol=0.08)
+
+  def testConcentrationLocControlsHowMuchWeightIsOnEachComponent(self):
+    with self.test_session() as sess:
+      dims = 1
+      vdm = vdm_lib.VectorDiffeomixture(
+          mix_loc=[[-1.], [0.], [1.]],
+          temperature=[0.5],
+          distribution=normal_lib.Normal(0., 1.),
+          loc=[
+              np.float32([-2.]),
+              np.float32([2.]),
+          ],
+          scale=[
+              linop_identity_lib.LinearOperatorScaledIdentity(
+                  num_rows=dims,
+                  multiplier=np.float32(0.5),
+                  is_positive_definite=True),
+          ] * 2,  # Use the same scale for each component.
+          quadrature_size=8,
+          validate_args=True)
+      samps = vdm.sample(10000)
+      self.assertAllEqual((10000, 3, 1), samps.shape)
+      samps_ = sess.run(samps).reshape(10000, 3)  # Make scalar event shape.
+      # One characteristic of putting more weight on a component is that the
+      # mean is closer to that component's mean.
+      # Get the mean for each batch member, the names signify the value of
+      # concentration for that batch member.
+      mean_neg1, mean_0, mean_1 = samps_.mean(axis=0)
+      # Since concentration is the concentration for component 0,
+      # concentration = -1 ==> more weight on component 1, which has mean = 2
+      # concentration = 0 ==> equal weight
+      # concentration = 1 ==> more weight on component 0, which has mean = -2
+      self.assertLess(-2, mean_1)
+      self.assertLess(mean_1, mean_0)
+      self.assertLess(mean_0, mean_neg1)
+      self.assertLess(mean_neg1, 2)
+      # Run this test as well, just because we can.
+      self.run_test_sample_consistent_mean_covariance(
+          sess.run, vdm, rtol=0.02, cov_rtol=0.08)
+
   def testMeanCovarianceNoBatchUncenteredNonStandardBase(self):
     with self.test_session() as sess:
       dims = 3
       vdm = vdm_lib.VectorDiffeomixture(
           mix_loc=[[0.], [4.]],
-          mix_scale=[10.],
+          temperature=[0.1],
           distribution=normal_lib.Normal(-1., 1.5),
           loc=[
               np.float32([-2.]),
@@ -177,7 +288,7 @@ class VectorDiffeomixtureTest(
       dims = 3
       vdm = vdm_lib.VectorDiffeomixture(
           mix_loc=[[0.], [4.]],
-          mix_scale=[10.],
+          temperature=[0.1],
           distribution=normal_lib.Normal(0., 1.),
           loc=[
               np.float32([[-2.]]),
@@ -205,7 +316,7 @@ class VectorDiffeomixtureTest(
       dims = 4
       vdm = vdm_lib.VectorDiffeomixture(
           mix_loc=[0.],
-          mix_scale=[1.],
+          temperature=[0.1],
           distribution=normal_lib.Normal(0., 1.),
           loc=[
               None,
@@ -229,29 +340,6 @@ class VectorDiffeomixtureTest(
       self.run_test_sample_consistent_log_prob(
           sess.run, vdm, radius=4., center=2., rtol=0.005)
 
-# TODO(jvdillon): We've tested that (i) .sample and .log_prob are consistent,
-# (ii) .mean, .stddev etc... and .sample are consistent. However, we haven't
-# tested that the quadrature approach well-approximates the integral.
-#
-# To that end, consider adding these tests:
-#
-# Test1: In the limit of high mix_scale, this approximates a discrete mixture,
-# and there are many discrete mixtures where we can explicitly compute
-# mean/var, etc... So test1 would choose one of those discrete mixtures and
-# show our mean/var/etc... is close to that.
-#
-# Test2: In the limit of low mix_scale, the a diffeomixture of Normal(-5, 1),
-# Normal(5, 1) should (I believe...must check) should look almost like
-# Uniform(-5, 5), and thus (i) .prob(x) should be about 1/10 for x in (-5, 5),
-# and (ii) the first few moments should approximately match that of
-# Uniform(-5, 5)
-#
-# Test3: If mix_loc is symmetric, then for any mix_scale, our
-# quadrature-based diffeomixture of Normal(-1, 1), Normal(1, 1) should have
-# mean zero, exactly.
-# TODO(jvdillon): Add more tests which verify broadcasting.
-
 if __name__ == "__main__":
   test.main()


@@ -50,20 +50,25 @@ __all__ = [
 def quadrature_scheme_softmaxnormal_gauss_hermite(
-    loc, scale, quadrature_size,
+    normal_loc, normal_scale, quadrature_size,
     validate_args=False, name=None):
   """Use Gauss-Hermite quadrature to form quadrature on `K - 1` simplex.
 
+  A `SoftmaxNormal` random variable `Y` may be generated via
+
+  ```
+  Y = SoftmaxCentered(X),
+  X = Normal(normal_loc, normal_scale)
+  ```
+
   Note: for a given `quadrature_size`, this method is generally less accurate
   than `quadrature_scheme_softmaxnormal_quantiles`.
 
   Args:
-    loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
-      Represents the `location` parameter of the SoftmaxNormal used for
-      selecting one of the `K` affine transformations.
-    scale: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
-      Represents the `scale` parameter of the SoftmaxNormal used for
-      selecting one of the `K` affine transformations.
+    normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
+      The location parameter of the Normal used to construct the SoftmaxNormal.
+    normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`.
+      The scale parameter of the Normal used to construct the SoftmaxNormal.
     quadrature_size: Python `int` scalar representing the number of quadrature
       points.
     validate_args: Python `bool`, default `False`. When `True` distribution
@@ -80,24 +85,25 @@ def quadrature_scheme_softmaxnormal_gauss_hermite(
       associated with each grid point.
   """
   with ops.name_scope(name, "quadrature_scheme_softmaxnormal_gauss_hermite",
-                      [loc, scale]):
-    loc = ops.convert_to_tensor(loc, name="loc")
-    dt = loc.dtype.base_dtype
-    scale = ops.convert_to_tensor(scale, dtype=dt, name="scale")
+                      [normal_loc, normal_scale]):
+    normal_loc = ops.convert_to_tensor(normal_loc, name="normal_loc")
+    dt = normal_loc.dtype.base_dtype
+    normal_scale = ops.convert_to_tensor(
+        normal_scale, dtype=dt, name="normal_scale")
 
-    loc = maybe_check_quadrature_param(loc, "loc", validate_args)
-    scale = maybe_check_quadrature_param(scale, "scale", validate_args)
+    normal_scale = maybe_check_quadrature_param(
+        normal_scale, "normal_scale", validate_args)
 
     grid, probs = np.polynomial.hermite.hermgauss(deg=quadrature_size)
-    grid = grid.astype(loc.dtype.as_numpy_dtype)
-    probs = probs.astype(loc.dtype.as_numpy_dtype)
+    grid = grid.astype(dt.dtype.as_numpy_dtype)
+    probs = probs.astype(dt.dtype.as_numpy_dtype)
     probs /= np.linalg.norm(probs, ord=1, keepdims=True)
-    probs = ops.convert_to_tensor(probs, name="probs", dtype=loc.dtype)
+    probs = ops.convert_to_tensor(probs, name="probs", dtype=dt)
 
     grid = softmax(
         -distribution_util.pad(
-            (loc[..., array_ops.newaxis] +
-             np.sqrt(2.) * scale[..., array_ops.newaxis] * grid),
+            (normal_loc[..., array_ops.newaxis] +
+             np.sqrt(2.) * normal_scale[..., array_ops.newaxis] * grid),
            axis=-2,
            front=True),
        axis=-2)  # shape: [B, components, deg]
@@ -106,18 +112,23 @@ def quadrature_scheme_softmaxnormal_gauss_hermite(
 def quadrature_scheme_softmaxnormal_quantiles(
-    loc, scale, quadrature_size,
+    normal_loc, normal_scale, quadrature_size,
     validate_args=False, name=None):
   """Use SoftmaxNormal quantiles to form quadrature on `K - 1` simplex.
 
+  A `SoftmaxNormal` random variable `Y` may be generated via
+
+  ```
+  Y = SoftmaxCentered(X),
+  X = Normal(normal_loc, normal_scale)
+  ```
+
   Args:
-    loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
-      Represents the `location` parameter of the SoftmaxNormal used for
-      selecting one of the `K` affine transformations.
-    scale: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
-      Represents the `scale` parameter of the SoftmaxNormal used for
-      selecting one of the `K` affine transformations.
-    quadrature_size: Python scalar `int` representing the number of quadrature
+    normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
+      The location parameter of the Normal used to construct the SoftmaxNormal.
+    normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`.
+      The scale parameter of the Normal used to construct the SoftmaxNormal.
+    quadrature_size: Python `int` scalar representing the number of quadrature
       points.
     validate_args: Python `bool`, default `False`. When `True` distribution
       parameters are checked for validity despite possibly degrading runtime
@@ -132,15 +143,17 @@ def quadrature_scheme_softmaxnormal_quantiles(
     probs: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the
       associated with each grid point.
   """
-  with ops.name_scope(name, "softmax_normal_grid_and_probs", [loc, scale]):
-    loc = ops.convert_to_tensor(loc, name="loc")
-    dt = loc.dtype.base_dtype
-    scale = ops.convert_to_tensor(scale, dtype=dt, name="scale")
+  with ops.name_scope(name, "softmax_normal_grid_and_probs",
+                      [normal_loc, normal_scale]):
+    normal_loc = ops.convert_to_tensor(normal_loc, name="normal_loc")
+    dt = normal_loc.dtype.base_dtype
+    normal_scale = ops.convert_to_tensor(
+        normal_scale, dtype=dt, name="normal_scale")
 
-    loc = maybe_check_quadrature_param(loc, "loc", validate_args)
-    scale = maybe_check_quadrature_param(scale, "scale", validate_args)
+    normal_scale = maybe_check_quadrature_param(
+        normal_scale, "normal_scale", validate_args)
 
-    dist = normal_lib.Normal(loc=loc, scale=scale)
+    dist = normal_lib.Normal(loc=normal_loc, scale=normal_scale)
 
     def _get_batch_ndims():
       """Helper to get dist.batch_shape.ndims, statically if possible."""
@@ -195,114 +208,51 @@ def quadrature_scheme_softmaxnormal_quantiles(
 class VectorDiffeomixture(distribution_lib.Distribution):
   """VectorDiffeomixture distribution.
 
-  The VectorDiffeomixture is an approximation to a [compound distribution](
-  https://en.wikipedia.org/wiki/Compound_probability_distribution), i.e.,
-
-  ```none
-  p(x) = int_{X} q(x | v) p(v) dv
-       = lim_{Q->infty} sum{ prob[i] q(x | loc=sum_k^K lambda[k;i] loc[k],
-                                            scale=sum_k^K lambda[k;i] scale[k])
-                            : i=0, ..., Q-1 }
-  ```
-
-  where `q(x | v)` is a vector version of the `distribution` argument and `p(v)`
-  is a SoftmaxNormal parameterized by `mix_loc` and `mix_scale`. The
-  vector-ization of `distribution` entails an affine transformation of iid
-  samples from `distribution`. The `prob` term is from quadrature and
-  `lambda[k] = sigmoid(mix_loc[k] + sqrt(2) mix_scale[k] grid[k])` where the
-  `grid` points correspond to the `prob`s.
-
-  In the non-approximation case, a draw from the mixture distribution (the
-  "prior") represents the convex weights for different affine transformations.
-  I.e., draw a mixing vector `v` (from the `K-1`-simplex) and let the final
-  sample be: `y = (sum_k^K v[k] scale[k]) @ x + (sum_k^K v[k] loc[k])` where `@`
-  denotes matrix multiplication. However, the non-approximate distribution does
-  not have an analytical probability density function (pdf). Therefore the
-  `VectorDiffeomixture` class implements an approximation based on
-  [numerical quadrature](
-  https://en.wikipedia.org/wiki/Numerical_integration) (default:
-  [Gauss--Hermite quadrature](
-  https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). I.e., in
-
-  Note: although the `VectorDiffeomixture` is approximately the
-  `SoftmaxNormal-Distribution` compound distribution, it is itself a valid
-  distribution. It possesses a `sample`, `log_prob`, `mean`, `covariance` which
-  are all mutually consistent.
-
-  #### Intended Use
-
-  This distribution is noteworthy because it implements a mixture of
-  `Vector`-ized distributions yet has samples differentiable in the
-  distribution's parameters (aka "reparameterized"). It has an analytical
-  density function with `O(dKQ)` complexity. `d` is the vector dimensionality,
-  `K` is the number of components, and `Q` is the number of quadrature points.
-  These properties make it well-suited for Bayesian Variational Inference, i.e.,
-  as a surrogate family for the posterior.
-
-  For large values of `mix_scale`, the `VectorDistribution` behaves increasingly
-  like a discrete mixture. (In most cases this limit is only achievable by also
-  increasing the quadrature polynomial degree, `Q`.)
-
-  The term `Vector` is consistent with similar named Tensorflow `Distribution`s.
-  For more details, see the "About `Vector` distributions in Tensorflow."
-  section.
-
-  The term `Diffeomixture` is a portmanteau of
-  [diffeomorphism](https://en.wikipedia.org/wiki/Diffeomorphism) and [compound
-  mixture](https://en.wikipedia.org/wiki/Compound_probability_distribution). For
-  more details, see the "About `Diffeomixture`s and reparametrization.`"
-  section.
-
-  #### Mathematical Details
-
-  The `VectorDiffeomixture` approximates a SoftmaxNormal-mixed ("prior")
-  [compound distribution](
-  https://en.wikipedia.org/wiki/Compound_probability_distribution).
-  Using variable-substitution and [numerical quadrature](
-  https://en.wikipedia.org/wiki/Numerical_integration) (default:
-  [Gauss--Hermite quadrature](
-  https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)) we can
-  redefine the distribution to be a parameter-less convex combination of `K`
-  different affine combinations of a `d` iid samples from `distribution`.
-
-  That is, defined over `R**d` this distribution is parameterized by a
-  (batch of) length-`K` `mix_loc` and `mix_scale` vectors, a length-`K` list of
-  (a batch of) length-`d` `loc` vectors, and a length-`K` list of `scale`
-  `LinearOperator`s each operating on a (batch of) length-`d` vector space.
-  Finally, a `distribution` parameter specifies the underlying base distribution
-  which is "lifted" to become multivariate ("lifting" is the same concept as in
-  `TransformedDistribution`).
-
-  The probability density function (pdf) is,
-
-  ```none
-  pdf(y; mix_loc, mix_scale, loc, scale, phi)
-    = sum{ prob[i] phi(f_inverse(x; i)) / abs(det(interp_scale[i]))
-         : i=0, ..., Q-1 }
-  ```
-
-  where, `phi` is the base distribution pdf, and,
-
-  ```none
-  f_inverse(x; i) = inv(interp_scale[i]) @ (x - interp_loc[i])
-  interp_loc[i] = sum{ lambda[k; i] loc[k] : k=0, ..., K-1 }
-  interp_scale[i] = sum{ lambda[k; i] scale[k] : k=0, ..., K-1 }
-  ```
-
-  and,
-
-  ```none
-  grid, weight = np.polynomial.hermite.hermgauss(quadrature_size)
-  prob[k] = weight[k] / sqrt(pi)
-  lambda[k; i] = sigmoid(mix_loc[k] + sqrt(2) mix_scale[k] grid[i])
-  ```
-
-  The distribution corresponding to `phi` must be a scalar-batch, scalar-event
-  distribution. Typically it is reparameterized. If not, it must be a function
-  of non-trainable parameters.
-
-  WARNING: If you backprop through a VectorDiffeomixture sample and the "base"
-  distribution is both: not `FULLY_REPARAMETERIZED` and a function of trainable
-  variables, then the gradient is not guaranteed correct!
+  A vector diffeomixture (VDM) is a distribution parameterized by a convex
+  combination of `K` component `loc` vectors, `loc[k], k = 0,...,K-1`, and `K`
+  `scale` matrices `scale[k], k = 0,..., K-1`. It approximates the following
+  [compound distribution]
+  (https://en.wikipedia.org/wiki/Compound_probability_distribution)
+
+  ```none
+  p(x) = int p(x | z) p(z) dz,
+  where z is in the K-simplex, and
+  p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k])
+  ```
+
+  The integral `int p(x | z) p(z) dz` is approximated with a quadrature scheme
+  adapted to the mixture density `p(z)`. The `N` quadrature points `z_{N, n}`
+  and weights `w_{N, n}` (which are non-negative and sum to 1) are chosen
+  such that
+
+  ```q_N(x) := sum_{n=1}^N w_{n, N} p(x | z_{N, n}) --> p(x)```
+
+  as `N --> infinity`.
+
+  Since `q_N(x)` is in fact a mixture (of `N` points), we may sample from
+  `q_N` exactly. It is important to note that the VDM is *defined* as `q_N`
+  above, and *not* `p(x)`. Therefore, sampling and pdf may be implemented as
+  exact (up to floating point error) methods.
+
+  A common choice for the conditional `p(x | z)` is a multivariate Normal.
+
+  The implemented marginal `p(z)` is the `SoftmaxNormal`, which is a
+  `K-1` dimensional Normal transformed by a `SoftmaxCentered` bijector, making
+  it a density on the `K`-simplex. That is,
+
+  ```
+  Z = SoftmaxCentered(X),
+  X = Normal(mix_loc / temperature, 1 / temperature)
+  ```
+
+  The default quadrature scheme chooses `z_{N, n}` as `N` midpoints of
+  the quantiles of `p(z)` (generalized quantiles if `K > 2`).
+
+  See [1] for more details.
+
+  [1]. "Quadrature Compound: An approximating family of distributions"
+       Joshua Dillon, Ian Langmore, arXiv preprints
+       https://arxiv.org/abs/1801.03080
 
   #### About `Vector` distributions in TensorFlow.
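To make the quadrature construction in the new docstring concrete, here is a NumPy-only sketch of the Gauss-Hermite variant (mirroring `quadrature_scheme_softmaxnormal_gauss_hermite` from this diff) for `K = 2`; function and variable names other than those in the diff are illustrative:

```python
import numpy as np

def softmax(x, axis):
  x = x - x.max(axis=axis, keepdims=True)
  e = np.exp(x)
  return e / e.sum(axis=axis, keepdims=True)

def softmaxnormal_gauss_hermite(normal_loc, normal_scale, quadrature_size):
  """Grid and probs on the K-simplex, following the TF code above."""
  normal_loc = np.asarray(normal_loc, dtype=np.float32)      # shape [..., K-1]
  normal_scale = np.asarray(normal_scale, dtype=np.float32)
  gh_x, gh_w = np.polynomial.hermite.hermgauss(quadrature_size)
  probs = gh_w / np.linalg.norm(gh_w, ord=1)                 # weights sum to 1
  z = (normal_loc[..., np.newaxis]
       + np.sqrt(2.) * normal_scale[..., np.newaxis] * gh_x)  # [..., K-1, N]
  # SoftmaxCentered: prepend a zero coordinate, then softmax over the K axis.
  zero = np.zeros_like(z[..., :1, :])
  grid = softmax(-np.concatenate([zero, z], axis=-2), axis=-2)  # [..., K, N]
  return grid, probs

# The VDM passes normal_loc = mix_loc / temperature and
# normal_scale = 1 / temperature.  Quadrature point n then contributes
# p(x | loc=sum_k grid[..., k, n] * loc[k],
#       scale=sum_k grid[..., k, n] * scale[k]) with weight probs[n].
grid, probs = softmaxnormal_gauss_hermite(
    normal_loc=np.array([0.]) / 1., normal_scale=1. / np.array([1.]),
    quadrature_size=8)
print(grid.shape, probs.shape)  # (2, 8) (8,)
```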
@@ -310,12 +260,11 @@ class VectorDiffeomixture(distribution_lib.Distribution):
   particularly useful in [variational Bayesian
   methods](https://en.wikipedia.org/wiki/Variational_Bayesian_methods).
 
-  Conditioned on a draw from the SoftmaxNormal, `Y|v` is a vector whose
+  Conditioned on a draw from the SoftmaxNormal, `X|z` is a vector whose
   components are linear combinations of affine transformations, thus is itself
-  an affine transformation. Therefore `Y|v` lives in the vector space generated
-  by vectors of affine-transformed distributions.
+  an affine transformation.
 
-  Note: The marginals `Y_1|v, ..., Y_d|v` are *not* generally identical to some
+  Note: The marginals `X_1|v, ..., X_d|v` are *not* generally identical to some
   parameterization of `distribution`. This is due to the fact that the sum of
   draws from `distribution` are not generally itself the same `distribution`.
@@ -331,12 +280,16 @@ class VectorDiffeomixture(distribution_lib.Distribution):
   optimize Monte-Carlo objectives. Such objectives are a finite-sample
   approximation of an expectation and arise throughout scientific computing.
 
+  WARNING: If you backprop through a VectorDiffeomixture sample and the "base"
+  distribution is both: not `FULLY_REPARAMETERIZED` and a function of trainable
+  variables, then the gradient is not guaranteed correct!
+
   #### Examples
 
   ```python
   tfd = tf.contrib.distributions
 
-  # Create two batches of VectorDiffeomixtures, one with mix_loc=[0.] and
+  # Create two batches of VectorDiffeomixtures, one with mix_loc=[0.],
   # another with mix_loc=[1]. In both cases, `K=2` and the affine
   # transformations involve:
   # k=0: loc=zeros(dims) scale=LinearOperatorScaledIdentity
@@ -344,7 +297,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
   dims = 5
   vdm = tfd.VectorDiffeomixture(
       mix_loc=[[0.], [1]],
-      mix_scale=[1.],
+      temperature=[1.],
       distribution=tfd.Normal(loc=0., scale=1.),
       loc=[
           None,  # Equivalent to `np.zeros(dims, dtype=np.float32)`.
@@ -364,7 +317,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
   def __init__(self,
                mix_loc,
-               mix_scale,
+               temperature,
                distribution,
                loc=None,
                scale=None,
@@ -373,15 +326,24 @@ class VectorDiffeomixture(distribution_lib.Distribution):
                validate_args=False,
                allow_nan_stats=True,
                name="VectorDiffeomixture"):
-    """Constructs the VectorDiffeomixture on `R**d`.
+    """Constructs the VectorDiffeomixture on `R^d`.
+
+    The vector diffeomixture (VDM) approximates the compound distribution
+
+    ```none
+    p(x) = int p(x | z) p(z) dz,
+    where z is in the K-simplex, and
+    p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k])
+    ```
 
     Args:
-      mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`. Represents
-        the `location` parameter of the SoftmaxNormal used for selecting one of
-        the `K` affine transformations.
-      mix_scale: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`.
-        Represents the `scale` parameter of the SoftmaxNormal used for selecting
-        one of the `K` affine transformations.
+      mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`.
+        In terms of samples, larger `mix_loc[..., k]` ==>
+        `Z` is more likely to put more weight on its `kth` component.
+      temperature: `float`-like `Tensor`. Broadcastable with `mix_loc`.
+        In terms of samples, smaller `temperature` means one component is more
+        likely to dominate. I.e., smaller `temperature` makes the VDM look more
+        like a standard mixture of `K` components.
       distribution: `tf.Distribution`-like instance. Distribution from which `d`
         iid samples are used as input to the selected affine transformation.
         Must be a scalar-batch, scalar-event distribution. Typically
@@ -401,8 +363,9 @@ class VectorDiffeomixture(distribution_lib.Distribution):
         transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`,
         `b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices
       quadrature_size: Python `int` scalar representing number of
-        quadrature points.
-      quadrature_fn: Python callable taking `mix_loc`, `mix_scale`,
+        quadrature points. Larger `quadrature_size` means `q_N(x)` better
+        approximates `p(x)`.
+      quadrature_fn: Python callable taking `normal_loc`, `normal_scale`,
         `quadrature_size`, `validate_args` and returning `tuple(grid, probs)`
         representing the SoftmaxNormal grid and corresponding normalized weight.
         normalized) weight.
@@ -430,7 +393,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
       ValueError: if `not distribution.is_scalar_event`.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[mix_loc, mix_scale]):
+    with ops.name_scope(name, values=[mix_loc, temperature]):
       if not scale or len(scale) < 2:
         raise ValueError("Must specify list (or list-like object) of scale "
                          "LinearOperators, one for each component with "
@@ -473,8 +436,15 @@ class VectorDiffeomixture(distribution_lib.Distribution):
         raise NotImplementedError("Currently only bimixtures are supported; "
                                   "len(scale)={} is not 2.".format(len(scale)))
 
+      mix_loc = ops.convert_to_tensor(
+          mix_loc, dtype=dtype, name="mix_loc")
+      temperature = ops.convert_to_tensor(
+          temperature, dtype=dtype, name="temperature")
+
       self._grid, probs = tuple(quadrature_fn(
-          mix_loc, mix_scale, quadrature_size, validate_args))
+          mix_loc / temperature,
+          1. / temperature,
+          quadrature_size,
+          validate_args))
 
       # Note: by creating the logits as `log(prob)` we ensure that
       # `self.mixture_distribution.logits` is equivalent to
@@ -618,7 +588,14 @@ class VectorDiffeomixture(distribution_lib.Distribution):
     weight = array_ops.gather(
         array_ops.reshape(self.grid, shape=[-1]),
         ids + offset)
-    weight = weight[..., array_ops.newaxis]
+    # At this point, weight flattened all batch dims into one.
+    # We also need to append a singleton to broadcast with event dims.
+    if self.batch_shape.is_fully_defined():
+      new_shape = [-1] + self.batch_shape.as_list() + [1]
+    else:
+      new_shape = array_ops.concat(
+          ([-1], self.batch_shape_tensor(), [1]), axis=0)
+    weight = array_ops.reshape(weight, shape=new_shape)
 
     if len(x) != 2:
       # We actually should have already triggered this exception. However as a
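For context on the `weight` bugfix above, a NumPy sketch of the shape problem it addresses; the shapes and the convex blend of the two component samples are illustrative paraphrases of the surrounding `_sample_n` code, not its exact variable names:

```python
import numpy as np

n, b1, b2, d = 4, 2, 3, 5             # draws, two batch dims, event size (illustrative)
weight = np.random.rand(n * b1 * b2)  # gathered weights: batch dims flattened into one

x0 = np.random.randn(n, b1, b2, d)    # component-0 interpolated samples
x1 = np.random.randn(n, b1, b2, d)    # component-1 interpolated samples

# A flat [n*b1*b2, 1] weight does not line up with [n, b1, b2, d] when there is
# more than one batch dimension, so the fix restores the batch shape and
# appends a singleton axis that broadcasts over the event dimension.
weight = weight.reshape([-1, b1, b2, 1])   # shape [n, b1, b2, 1]
samples = x1 + weight * (x0 - x1)          # broadcasts to [n, b1, b2, d]
```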
@@ -686,7 +663,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
     # To compute E[Cov(Z|V)], we'll add matrices within three categories:
     # scaled-identity, diagonal, and full. Then we'll combine these at the end.
-    scaled_identity = None
+    scale_identity_multiplier = None
     diag = None
     full = None
@@ -694,10 +671,12 @@ class VectorDiffeomixture(distribution_lib.Distribution):
       s = aff.scale  # Just in case aff.scale has side-effects, we'll call once.
       if (s is None
           or isinstance(s, linop_identity_lib.LinearOperatorIdentity)):
-        scaled_identity = add(scaled_identity, p[..., k, array_ops.newaxis])
+        scale_identity_multiplier = add(scale_identity_multiplier,
+                                        p[..., k, array_ops.newaxis])
       elif isinstance(s, linop_identity_lib.LinearOperatorScaledIdentity):
-        scaled_identity = add(scaled_identity, (p[..., k, array_ops.newaxis] *
-                                                math_ops.square(s.multiplier)))
+        scale_identity_multiplier = add(
+            scale_identity_multiplier,
+            (p[..., k, array_ops.newaxis] * math_ops.square(s.multiplier)))
       elif isinstance(s, linop_diag_lib.LinearOperatorDiag):
         diag = add(diag, (p[..., k, array_ops.newaxis] *
                           math_ops.square(s.diag_part())))
@@ -709,12 +688,13 @@ class VectorDiffeomixture(distribution_lib.Distribution):
         full = add(full, x)
 
     # We must now account for the fact that the base distribution might have a
-    # non-unity variance. Recall that `Cov(SX+m) = S.T Cov(X) S = S.T S Var(X)`.
+    # non-unity variance. Recall that, since X ~ iid Law(X_0),
+    # `Cov(SX+m) = S Cov(X) S.T = S S.T Diag(Var(X_0))`.
    # We can scale by `Var(X)` (vs `Cov(X)`) since X corresponds to `d` iid
    # samples from a scalar-event distribution.
    v = self.distribution.variance()
-    if scaled_identity is not None:
-      scaled_identity *= v
+    if scale_identity_multiplier is not None:
+      scale_identity_multiplier *= v
    if diag is not None:
      diag *= v[..., array_ops.newaxis]
    if full is not None:
@@ -723,10 +703,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
     if diag_only:
       # Apparently we don't need the full matrix, just the diagonal.
       r = add(diag, full)
-      if r is None and scaled_identity is not None:
+      if r is None and scale_identity_multiplier is not None:
         ones = array_ops.ones(self.event_shape_tensor(), dtype=self.dtype)
-        return scaled_identity * ones
-      return add(r, scaled_identity)
+        return scale_identity_multiplier[..., array_ops.newaxis] * ones
+      return add(r, scale_identity_multiplier)
 
     # `None` indicates we don't know if the result is positive-definite.
     is_positive_definite = (True if all(aff.scale.is_positive_definite
@@ -742,10 +722,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
       to_add.append(linop_full_lib.LinearOperatorFullMatrix(
           matrix=full,
           is_positive_definite=is_positive_definite))
-    if scaled_identity is not None:
+    if scale_identity_multiplier is not None:
       to_add.append(linop_identity_lib.LinearOperatorScaledIdentity(
           num_rows=self.event_shape_tensor()[0],
-          multiplier=scaled_identity,
+          multiplier=scale_identity_multiplier,
           is_positive_definite=is_positive_definite))
 
     return (linop_add_lib.add_operators(to_add)[0].to_dense()