* BUGFIX: See code associated with scale_identity_multiplier

* BUGFIX: See code associated with 'weight' inside sample_n
* Use parameterization `temperature` rather than `mix_scale` (see the sketch below).
* Simplify documentation and link to the arXiv paper for details
* Document that we allow `temperature` to have any shape broadcastable with `mix_loc`. This is a moot point to some degree since we require K = 2 now.
* Add some tests

PiperOrigin-RevId: 183098275
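Reviewer note (not part of this change): the new parameterization builds the mixing distribution as `Z = SoftmaxCentered(X)` with `X ~ Normal(mix_loc / temperature, 1 / temperature)` (see the constructor diff below), so smaller `temperature` pushes the mixing weights toward a vertex of the simplex. A minimal NumPy sketch of that effect, with `softmax_centered` as a local stand-in for the bijector:

```python
import numpy as np

def softmax_centered(x):
  # Append a zero logit (roughly what the SoftmaxCentered bijector does), then softmax.
  logits = np.concatenate([x, np.zeros_like(x)], axis=-1)
  e = np.exp(logits - logits.max(axis=-1, keepdims=True))
  return e / e.sum(axis=-1, keepdims=True)

rng = np.random.RandomState(0)
mix_loc = 0.
for temperature in (2., 1., 0.2):
  # Z = SoftmaxCentered(X), X ~ Normal(mix_loc / temperature, 1 / temperature).
  x = mix_loc / temperature + rng.randn(5, 1) / temperature
  print(temperature, softmax_centered(x).round(3))
# Smaller temperature ==> weights concentrate near (1, 0) or (0, 1), i.e. the VDM
# behaves more like a discrete two-component mixture.
```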
Ian Langmore 2018-01-24 09:49:47 -08:00 committed by TensorFlower Gardener
parent ffdae0a357
commit 7b62a71e2d
2 changed files with 256 additions and 188 deletions
tensorflow/contrib/distributions/python


@@ -27,6 +27,8 @@ from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib
from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib
from tensorflow.python.platform import test
rng = np.random.RandomState(0)
class VectorDiffeomixtureTest(
test_util.VectorDistributionTestHelpers, test.TestCase):
@@ -37,7 +39,7 @@ class VectorDiffeomixtureTest(
dims = 4
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [1.]],
mix_scale=[1.],
temperature=[1.],
distribution=normal_lib.Normal(0., 1.),
loc=[
None,
@@ -66,7 +68,7 @@ class VectorDiffeomixtureTest(
dims = 4
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [1.]],
mix_scale=[1.],
temperature=[1.],
distribution=normal_lib.Normal(1., 1.5),
loc=[
None,
@@ -95,7 +97,7 @@ class VectorDiffeomixtureTest(
dims = 4
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [1.]],
mix_scale=[1.],
temperature=[1.],
distribution=normal_lib.Normal(0., 1.),
loc=[
None,
@@ -122,12 +124,39 @@ class VectorDiffeomixtureTest(
self.run_test_sample_consistent_log_prob(
sess.run, vdm, radius=4., center=2., rtol=0.01)
def testSampleProbConsistentBroadcastMixTwoBatchDims(self):
dims = 4
loc_1 = rng.randn(2, 3, dims).astype(np.float32)
with self.test_session() as sess:
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=(rng.rand(2, 3, 1) - 0.5).astype(np.float32),
temperature=[1.],
distribution=normal_lib.Normal(0., 1.),
loc=[
None,
loc_1,
],
scale=[
linop_identity_lib.LinearOperatorScaledIdentity(
num_rows=dims,
multiplier=[np.float32(1.1)],
is_positive_definite=True),
] * 2,
validate_args=True)
# Ball centered at component0's mean.
self.run_test_sample_consistent_log_prob(
sess.run, vdm, radius=2., center=0., rtol=0.01)
# Larger ball centered at component1's mean.
self.run_test_sample_consistent_log_prob(
sess.run, vdm, radius=3., center=loc_1, rtol=0.02)
def testMeanCovarianceNoBatch(self):
with self.test_session() as sess:
dims = 3
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [4.]],
mix_scale=[10.],
temperature=[1 / 10.],
distribution=normal_lib.Normal(0., 1.),
loc=[
np.float32([-2.]),
@@ -147,12 +176,94 @@ class VectorDiffeomixtureTest(
self.run_test_sample_consistent_mean_covariance(
sess.run, vdm, rtol=0.02, cov_rtol=0.08)
def testTemperatureControlsHowMuchThisLooksLikeDiscreteMixture(self):
# As temperature decreases, this should approach a mixture of normals, with
# components at -2, 2.
with self.test_session() as sess:
dims = 1
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[0.],
temperature=[[2.], [1.], [0.2]],
distribution=normal_lib.Normal(0., 1.),
loc=[
np.float32([-2.]),
np.float32([2.]),
],
scale=[
linop_identity_lib.LinearOperatorScaledIdentity(
num_rows=dims,
multiplier=np.float32(0.5),
is_positive_definite=True),
] * 2, # Use the same scale for each component.
quadrature_size=8,
validate_args=True)
samps = vdm.sample(10000)
self.assertAllEqual((10000, 3, 1), samps.shape)
samps_ = sess.run(samps).reshape(10000, 3) # Make scalar event shape.
# One characteristic of a discrete mixture (as opposed to a "smear") is
# that more weight is put near the component centers at -2, 2, and thus
# less weight is put near the origin.
prob_of_being_near_origin = (np.abs(samps_) < 1).mean(axis=0)
self.assertGreater(
prob_of_being_near_origin[0], prob_of_being_near_origin[1])
self.assertGreater(
prob_of_being_near_origin[1], prob_of_being_near_origin[2])
# Run this test as well, just because we can.
self.run_test_sample_consistent_mean_covariance(
sess.run, vdm, rtol=0.02, cov_rtol=0.08)
def testConcentrationLocControlsHowMuchWeightIsOnEachComponent(self):
with self.test_session() as sess:
dims = 1
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[-1.], [0.], [1.]],
temperature=[0.5],
distribution=normal_lib.Normal(0., 1.),
loc=[
np.float32([-2.]),
np.float32([2.]),
],
scale=[
linop_identity_lib.LinearOperatorScaledIdentity(
num_rows=dims,
multiplier=np.float32(0.5),
is_positive_definite=True),
] * 2, # Use the same scale for each component.
quadrature_size=8,
validate_args=True)
samps = vdm.sample(10000)
self.assertAllEqual((10000, 3, 1), samps.shape)
samps_ = sess.run(samps).reshape(10000, 3) # Make scalar event shape.
# One characteristic of putting more weight on a component is that the
# mean is closer to that component's mean.
# Get the mean for each batch member, the names signify the value of
# concentration for that batch member.
mean_neg1, mean_0, mean_1 = samps_.mean(axis=0)
# Since concentration is the concentration for component 0,
# concentration = -1 ==> more weight on component 1, which has mean = 2
# concentration = 0 ==> equal weight
# concentration = 1 ==> more weight on component 0, which has mean = -2
self.assertLess(-2, mean_1)
self.assertLess(mean_1, mean_0)
self.assertLess(mean_0, mean_neg1)
self.assertLess(mean_neg1, 2)
# Run this test as well, just because we can.
self.run_test_sample_consistent_mean_covariance(
sess.run, vdm, rtol=0.02, cov_rtol=0.08)
def testMeanCovarianceNoBatchUncenteredNonStandardBase(self):
with self.test_session() as sess:
dims = 3
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [4.]],
mix_scale=[10.],
temperature=[0.1],
distribution=normal_lib.Normal(-1., 1.5),
loc=[
np.float32([-2.]),
@@ -177,7 +288,7 @@ class VectorDiffeomixtureTest(
dims = 3
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [4.]],
mix_scale=[10.],
temperature=[0.1],
distribution=normal_lib.Normal(0., 1.),
loc=[
np.float32([[-2.]]),
@@ -205,7 +316,7 @@ class VectorDiffeomixtureTest(
dims = 4
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[0.],
mix_scale=[1.],
temperature=[0.1],
distribution=normal_lib.Normal(0., 1.),
loc=[
None,
@@ -229,29 +340,6 @@ class VectorDiffeomixtureTest(
self.run_test_sample_consistent_log_prob(
sess.run, vdm, radius=4., center=2., rtol=0.005)
# TODO(jvdillon): We've tested that (i) .sample and .log_prob are consistent,
# (ii) .mean, .stddev etc... and .sample are consistent. However, we haven't
# tested that the quadrature approach well-approximates the integral.
#
# To that end, consider adding these tests:
#
# Test1: In the limit of high mix_scale, this approximates a discrete mixture,
# and there are many discrete mixtures where we can explicitly compute
# mean/var, etc... So test1 would choose one of those discrete mixtures and
# show our mean/var/etc... is close to that.
#
# Test2: In the limit of low mix_scale, a diffeomixture of Normal(-5, 1),
# Normal(5, 1) should (I believe...must check) look almost like
# Uniform(-5, 5), and thus (i) .prob(x) should be about 1/10 for x in (-5, 5),
# and (ii) the first few moments should approximately match those of
# Uniform(-5, 5)
#
# Test3: If mix_loc is symmetric, then for any mix_scale, our
# quadrature-based diffeomixture of Normal(-1, 1), Normal(1, 1) should have
# mean zero, exactly.
# TODO(jvdillon): Add more tests which verify broadcasting.
if __name__ == "__main__":
test.main()


@@ -50,20 +50,25 @@ __all__ = [
def quadrature_scheme_softmaxnormal_gauss_hermite(
loc, scale, quadrature_size,
normal_loc, normal_scale, quadrature_size,
validate_args=False, name=None):
"""Use Gauss-Hermite quadrature to form quadrature on `K - 1` simplex.
A `SoftmaxNormal` random variable `Y` may be generated via
```
Y = SoftmaxCentered(X),
X = Normal(normal_loc, normal_scale)
```
Note: for a given `quadrature_size`, this method is generally less accurate
than `quadrature_scheme_softmaxnormal_quantiles`.
Args:
loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
Represents the `location` parameter of the SoftmaxNormal used for
selecting one of the `K` affine transformations.
scale: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
Represents the `scale` parameter of the SoftmaxNormal used for
selecting one of the `K` affine transformations.
normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
The location parameter of the Normal used to construct the SoftmaxNormal.
normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`.
The scale parameter of the Normal used to construct the SoftmaxNormal.
quadrature_size: Python `int` scalar representing the number of quadrature
points.
validate_args: Python `bool`, default `False`. When `True` distribution
@@ -80,24 +85,25 @@ def quadrature_scheme_softmaxnormal_gauss_hermite(
associated with each grid point.
"""
with ops.name_scope(name, "quadrature_scheme_softmaxnormal_gauss_hermite",
[loc, scale]):
loc = ops.convert_to_tensor(loc, name="loc")
dt = loc.dtype.base_dtype
scale = ops.convert_to_tensor(scale, dtype=dt, name="scale")
[normal_loc, normal_scale]):
normal_loc = ops.convert_to_tensor(normal_loc, name="normal_loc")
dt = normal_loc.dtype.base_dtype
normal_scale = ops.convert_to_tensor(
normal_scale, dtype=dt, name="normal_scale")
loc = maybe_check_quadrature_param(loc, "loc", validate_args)
scale = maybe_check_quadrature_param(scale, "scale", validate_args)
normal_scale = maybe_check_quadrature_param(
normal_scale, "normal_scale", validate_args)
grid, probs = np.polynomial.hermite.hermgauss(deg=quadrature_size)
grid = grid.astype(loc.dtype.as_numpy_dtype)
probs = probs.astype(loc.dtype.as_numpy_dtype)
grid = grid.astype(dt.dtype.as_numpy_dtype)
probs = probs.astype(dt.dtype.as_numpy_dtype)
probs /= np.linalg.norm(probs, ord=1, keepdims=True)
probs = ops.convert_to_tensor(probs, name="probs", dtype=loc.dtype)
probs = ops.convert_to_tensor(probs, name="probs", dtype=dt)
grid = softmax(
-distribution_util.pad(
(loc[..., array_ops.newaxis] +
np.sqrt(2.) * scale[..., array_ops.newaxis] * grid),
(normal_loc[..., array_ops.newaxis] +
np.sqrt(2.) * normal_scale[..., array_ops.newaxis] * grid),
axis=-2,
front=True),
axis=-2) # shape: [B, components, deg]
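Editorial sketch (NumPy, not the library code) of what the hunk above computes for a scalar batch and K = 2; the function name and argument handling here are illustrative only:

```python
import numpy as np

def softmaxnormal_gauss_hermite(normal_loc, normal_scale, quadrature_size):
  # Gauss-Hermite nodes/weights; normalize the weights so they sum to 1.
  grid, probs = np.polynomial.hermite.hermgauss(quadrature_size)
  probs = probs / np.linalg.norm(probs, ord=1, keepdims=True)
  # Change of variables x = loc + sqrt(2) * scale * node, pad a zero row in
  # front, negate, and softmax over the component axis (mirroring the hunk).
  logits = (normal_loc[..., np.newaxis]
            + np.sqrt(2.) * normal_scale[..., np.newaxis] * grid)
  logits = -np.concatenate([np.zeros_like(logits), logits], axis=-2)
  z = np.exp(logits - logits.max(axis=-2, keepdims=True))
  z = z / z.sum(axis=-2, keepdims=True)   # shape: [K, quadrature_size]
  return z, probs

z, probs = softmaxnormal_gauss_hermite(np.float32([0.]), np.float32([1.]), 8)
```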
@@ -106,18 +112,23 @@ def quadrature_scheme_softmaxnormal_gauss_hermite(
def quadrature_scheme_softmaxnormal_quantiles(
loc, scale, quadrature_size,
normal_loc, normal_scale, quadrature_size,
validate_args=False, name=None):
"""Use SoftmaxNormal quantiles to form quadrature on `K - 1` simplex.
A `SoftmaxNormal` random variable `Y` may be generated via
```
Y = SoftmaxCentered(X),
X = Normal(normal_loc, normal_scale)
```
Args:
loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
Represents the `location` parameter of the SoftmaxNormal used for
selecting one of the `K` affine transformations.
scale: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
Represents the `scale` parameter of the SoftmaxNormal used for
selecting one of the `K` affine transformations.
quadrature_size: Python scalar `int` representing the number of quadrature
normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
The location parameter of the Normal used to construct the SoftmaxNormal.
normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`.
The scale parameter of the Normal used to construct the SoftmaxNormal.
quadrature_size: Python `int` scalar representing the number of quadrature
points.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
@@ -132,15 +143,17 @@ def quadrature_scheme_softmaxnormal_quantiles(
probs: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the
normalized weight associated with each grid point.
"""
with ops.name_scope(name, "softmax_normal_grid_and_probs", [loc, scale]):
loc = ops.convert_to_tensor(loc, name="loc")
dt = loc.dtype.base_dtype
scale = ops.convert_to_tensor(scale, dtype=dt, name="scale")
with ops.name_scope(name, "softmax_normal_grid_and_probs",
[normal_loc, normal_scale]):
normal_loc = ops.convert_to_tensor(normal_loc, name="normal_loc")
dt = normal_loc.dtype.base_dtype
normal_scale = ops.convert_to_tensor(
normal_scale, dtype=dt, name="normal_scale")
loc = maybe_check_quadrature_param(loc, "loc", validate_args)
scale = maybe_check_quadrature_param(scale, "scale", validate_args)
normal_scale = maybe_check_quadrature_param(
normal_scale, "normal_scale", validate_args)
dist = normal_lib.Normal(loc=loc, scale=scale)
dist = normal_lib.Normal(loc=normal_loc, scale=normal_scale)
def _get_batch_ndims():
"""Helper to get dist.batch_shape.ndims, statically if possible."""
@@ -195,114 +208,51 @@ def quadrature_scheme_softmaxnormal_quantiles(
class VectorDiffeomixture(distribution_lib.Distribution):
"""VectorDiffeomixture distribution.
The VectorDiffeomixture is an approximation to a [compound distribution](
https://en.wikipedia.org/wiki/Compound_probability_distribution), i.e.,
A vector diffeomixture (VDM) is a distribution parameterized by a convex
combination of `K` component `loc` vectors, `loc[k], k = 0,...,K-1`, and `K`
`scale` matrices `scale[k], k = 0,..., K-1`. It approximates the following
[compound distribution]
(https://en.wikipedia.org/wiki/Compound_probability_distribution)
```none
p(x) = int_{X} q(x | v) p(v) dv
= lim_{Q->infty} sum{ prob[i] q(x | loc=sum_k^K lambda[k;i] loc[k],
scale=sum_k^K lambda[k;i] scale[k])
: i=0, ..., Q-1 }
p(x) = int p(x | z) p(z) dz,
where z is in the K-simplex, and
p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k])
```
where `q(x | v)` is a vector version of the `distribution` argument and `p(v)`
is a SoftmaxNormal parameterized by `mix_loc` and `mix_scale`. The
vector-ization of `distribution` entails an affine transformation of iid
samples from `distribution`. The `prob` term is from quadrature and
`lambda[k] = sigmoid(mix_loc[k] + sqrt(2) mix_scale[k] grid[k])` where the
`grid` points correspond to the `prob`s.
The integral `int p(x | z) p(z) dz` is approximated with a quadrature scheme
adapted to the mixture density `p(z)`. The `N` quadrature points `z_{N, n}`
and weights `w_{N, n}` (which are non-negative and sum to 1) are chosen
such that
In the non-approximation case, a draw from the mixture distribution (the
"prior") represents the convex weights for different affine transformations.
I.e., draw a mixing vector `v` (from the `K-1`-simplex) and let the final
sample be: `y = (sum_k^K v[k] scale[k]) @ x + (sum_k^K v[k] loc[k])` where `@`
denotes matrix multiplication. However, the non-approximate distribution does
not have an analytical probability density function (pdf). Therefore the
`VectorDiffeomixture` class implements an approximation based on
[numerical quadrature](
https://en.wikipedia.org/wiki/Numerical_integration) (default:
[Gauss--Hermite quadrature](
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). I.e., in
Note: although the `VectorDiffeomixture` is approximately the
`SoftmaxNormal-Distribution` compound distribution, it is itself a valid
distribution. It possesses a `sample`, `log_prob`, `mean`, `covariance` which
are all mutually consistent.
```q_N(x) := sum_{n=1}^N w_{n, N} p(x | z_{N, n}) --> p(x)```
as `N --> infinity`.
#### Intended Use
This distribution is noteworthy because it implements a mixture of
`Vector`-ized distributions yet has samples differentiable in the
distribution's parameters (aka "reparameterized"). It has an analytical
density function with `O(dKQ)` complexity. `d` is the vector dimensionality,
`K` is the number of components, and `Q` is the number of quadrature points.
These properties make it well-suited for Bayesian Variational Inference, i.e.,
as a surrogate family for the posterior.
Since `q_N(x)` is in fact a mixture (of `N` points), we may sample from
`q_N` exactly. It is important to note that the VDM is *defined* as `q_N`
above, and *not* `p(x)`. Therefore, sampling and pdf may be implemented as
exact (up to floating point error) methods.
For large values of `mix_scale`, the `VectorDistribution` behaves increasingly
like a discrete mixture. (In most cases this limit is only achievable by also
increasing the quadrature polynomial degree, `Q`.)
A common choice for the conditional `p(x | z)` is a multivariate Normal.
The term `Vector` is consistent with similarly named TensorFlow `Distribution`s.
For more details, see the "About `Vector` distributions in TensorFlow."
section.
The implemented marginal `p(z)` is the `SoftmaxNormal`, which is a
`K-1` dimensional Normal transformed by a `SoftmaxCentered` bijector, making
it a density on the `K`-simplex. That is,
The term `Diffeomixture` is a portmanteau of
[diffeomorphism](https://en.wikipedia.org/wiki/Diffeomorphism) and [compound
mixture](https://en.wikipedia.org/wiki/Compound_probability_distribution). For
more details, see the "About `Diffeomixture`s and reparametrization.`"
section.
#### Mathematical Details
The `VectorDiffeomixture` approximates a SoftmaxNormal-mixed ("prior")
[compound distribution](
https://en.wikipedia.org/wiki/Compound_probability_distribution).
Using variable-substitution and [numerical quadrature](
https://en.wikipedia.org/wiki/Numerical_integration) (default:
[Gauss--Hermite quadrature](
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)) we can
redefine the distribution to be a parameter-less convex combination of `K`
different affine combinations of `d` iid samples from `distribution`.
That is, defined over `R**d` this distribution is parameterized by a
(batch of) length-`K` `mix_loc` and `mix_scale` vectors, a length-`K` list of
(a batch of) length-`d` `loc` vectors, and a length-`K` list of `scale`
`LinearOperator`s each operating on a (batch of) length-`d` vector space.
Finally, a `distribution` parameter specifies the underlying base distribution
which is "lifted" to become multivariate ("lifting" is the same concept as in
`TransformedDistribution`).
The probability density function (pdf) is,
```none
pdf(y; mix_loc, mix_scale, loc, scale, phi)
= sum{ prob[i] phi(f_inverse(x; i)) / abs(det(interp_scale[i]))
: i=0, ..., Q-1 }
```
Z = SoftmaxCentered(X),
X = Normal(mix_loc / temperature, 1 / temperature)
```
where, `phi` is the base distribution pdf, and,
The default quadrature scheme chooses `z_{N, n}` as `N` midpoints of
the quantiles of `p(z)` (generalized quantiles if `K > 2`).
```none
f_inverse(x; i) = inv(interp_scale[i]) @ (x - interp_loc[i])
interp_loc[i] = sum{ lambda[k; i] loc[k] : k=0, ..., K-1 }
interp_scale[i] = sum{ lambda[k; i] scale[k] : k=0, ..., K-1 }
```
See [1] for more details.
and,
```none
grid, weight = np.polynomial.hermite.hermgauss(quadrature_size)
prob[k] = weight[k] / sqrt(pi)
lambda[k; i] = sigmoid(mix_loc[k] + sqrt(2) mix_scale[k] grid[i])
```
The distribution corresponding to `phi` must be a scalar-batch, scalar-event
distribution. Typically it is reparameterized. If not, it must be a function
of non-trainable parameters.
WARNING: If you backprop through a VectorDiffeomixture sample and the "base"
distribution is both: not `FULLY_REPARAMETERIZED` and a function of trainable
variables, then the gradient is not guaranteed correct!
[1]. "Quadrature Compound: An approximating family of distributions"
Joshua Dillon, Ian Langmore, arXiv preprints
https://arxiv.org/abs/1801.03080
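A toy editorial illustration of the approximation above for `d = 1`, `K = 2` Normal components; the grid `z` and weights `w` below are made-up stand-ins for whatever quadrature scheme is used:

```python
import numpy as np

def normal_pdf(x, loc, scale):
  return np.exp(-0.5 * ((x - loc) / scale) ** 2) / (scale * np.sqrt(2. * np.pi))

loc = np.array([-2., 2.])            # component locations loc[k]
scale = np.array([0.5, 0.5])         # component scales scale[k] (d = 1)
z = np.array([[0.9, 0.1],            # quadrature points z[n, k] on the 2-simplex
              [0.5, 0.5],
              [0.1, 0.9]])
w = np.array([0.25, 0.5, 0.25])      # non-negative weights w[n] summing to 1

def q_N(x):
  # q_N(x) = sum_n w[n] p(x | loc=sum_k z[n,k] loc[k], scale=sum_k z[n,k] scale[k])
  return np.sum(w * normal_pdf(x, loc=z @ loc, scale=z @ scale))

print(q_N(0.))  # q_N is itself an exact finite mixture, so sample/log_prob are exact.
```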
#### About `Vector` distributions in TensorFlow.
@@ -310,12 +260,11 @@ class VectorDiffeomixture(distribution_lib.Distribution):
particularly useful in [variational Bayesian
methods](https://en.wikipedia.org/wiki/Variational_Bayesian_methods).
Conditioned on a draw from the SoftmaxNormal, `Y|v` is a vector whose
Conditioned on a draw from the SoftmaxNormal, `X|z` is a vector whose
components are linear combinations of affine transformations, thus is itself
an affine transformation. Therefore `Y|v` lives in the vector space generated
by vectors of affine-transformed distributions.
an affine transformation.
Note: The marginals `Y_1|v, ..., Y_d|v` are *not* generally identical to some
Note: The marginals `X_1|v, ..., X_d|v` are *not* generally identical to some
parameterization of `distribution`. This is due to the fact that the sum of
draws from `distribution` is not generally itself the same `distribution`.
@@ -331,12 +280,16 @@ class VectorDiffeomixture(distribution_lib.Distribution):
optimize Monte-Carlo objectives. Such objectives are a finite-sample
approximation of an expectation and arise throughout scientific computing.
WARNING: If you backprop through a VectorDiffeomixture sample and the "base"
distribution is both: not `FULLY_REPARAMETERIZED` and a function of trainable
variables, then the gradient is not guaranteed correct!
#### Examples
```python
tfd = tf.contrib.distributions
# Create two batches of VectorDiffeomixtures, one with mix_loc=[0.] and
# Create two batches of VectorDiffeomixtures, one with mix_loc=[0.],
# another with mix_loc=[1]. In both cases, `K=2` and the affine
# transformations involve:
# k=0: loc=zeros(dims) scale=LinearOperatorScaledIdentity
@@ -344,7 +297,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
dims = 5
vdm = tfd.VectorDiffeomixture(
mix_loc=[[0.], [1]],
mix_scale=[1.],
temperature=[1.],
distribution=tfd.Normal(loc=0., scale=1.),
loc=[
None, # Equivalent to `np.zeros(dims, dtype=np.float32)`.
@@ -364,7 +317,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
def __init__(self,
mix_loc,
mix_scale,
temperature,
distribution,
loc=None,
scale=None,
@@ -373,15 +326,24 @@ class VectorDiffeomixture(distribution_lib.Distribution):
validate_args=False,
allow_nan_stats=True,
name="VectorDiffeomixture"):
"""Constructs the VectorDiffeomixture on `R**d`.
"""Constructs the VectorDiffeomixture on `R^d`.
The vector diffeomixture (VDM) approximates the compound distribution
```none
p(x) = int p(x | z) p(z) dz,
where z is in the K-simplex, and
p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k])
```
Args:
mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`. Represents
the `location` parameter of the SoftmaxNormal used for selecting one of
the `K` affine transformations.
mix_scale: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`.
Represents the `scale` parameter of the SoftmaxNormal used for selecting
one of the `K` affine transformations.
mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`.
In terms of samples, larger `mix_loc[..., k]` ==>
`Z` is more likely to put more weight on its `kth` component.
temperature: `float`-like `Tensor`. Broadcastable with `mix_loc`.
In terms of samples, smaller `temperature` means one component is more
likely to dominate. I.e., smaller `temperature` makes the VDM look more
like a standard mixture of `K` components.
distribution: `tf.Distribution`-like instance. Distribution from which `d`
iid samples are used as input to the selected affine transformation.
Must be a scalar-batch, scalar-event distribution. Typically
@@ -401,8 +363,9 @@ class VectorDiffeomixture(distribution_lib.Distribution):
transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`,
`b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices
quadrature_size: Python `int` scalar representing number of
quadrature points.
quadrature_fn: Python callable taking `mix_loc`, `mix_scale`,
quadrature points. Larger `quadrature_size` means `q_N(x)` better
approximates `p(x)`.
quadrature_fn: Python callable taking `normal_loc`, `normal_scale`,
`quadrature_size`, `validate_args` and returning `tuple(grid, probs)`
representing the SoftmaxNormal grid and corresponding normalized weight.
@@ -430,7 +393,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
ValueError: if `not distribution.is_scalar_event`.
"""
parameters = locals()
with ops.name_scope(name, values=[mix_loc, mix_scale]):
with ops.name_scope(name, values=[mix_loc, temperature]):
if not scale or len(scale) < 2:
raise ValueError("Must specify list (or list-like object) of scale "
"LinearOperators, one for each component with "
@@ -473,8 +436,15 @@ class VectorDiffeomixture(distribution_lib.Distribution):
raise NotImplementedError("Currently only bimixtures are supported; "
"len(scale)={} is not 2.".format(len(scale)))
mix_loc = ops.convert_to_tensor(
mix_loc, dtype=dtype, name="mix_loc")
temperature = ops.convert_to_tensor(
temperature, dtype=dtype, name="temperature")
self._grid, probs = tuple(quadrature_fn(
mix_loc, mix_scale, quadrature_size, validate_args))
mix_loc / temperature,
1. / temperature,
quadrature_size,
validate_args))
# Note: by creating the logits as `log(prob)` we ensure that
# `self.mixture_distribution.logits` is equivalent to
@@ -618,7 +588,14 @@ class VectorDiffeomixture(distribution_lib.Distribution):
weight = array_ops.gather(
array_ops.reshape(self.grid, shape=[-1]),
ids + offset)
weight = weight[..., array_ops.newaxis]
# At this point, `weight` has all batch dims flattened into one.
# We also need to append a singleton to broadcast with event dims.
if self.batch_shape.is_fully_defined():
new_shape = [-1] + self.batch_shape.as_list() + [1]
else:
new_shape = array_ops.concat(
([-1], self.batch_shape_tensor(), [1]), axis=0)
weight = array_ops.reshape(weight, shape=new_shape)
if len(x) != 2:
# We actually should have already triggered this exception. However as a
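Editorial sketch of the reshape in the `weight` bugfix above, using hypothetical shapes (10 samples, `batch_shape = [2, 3]`, event size `d = 4`) and NumPy in place of the TF ops:

```python
import numpy as np

n, batch_shape, d = 10, [2, 3], 4
# After the gather, the batch dims of `weight` are flattened into one.
weight = np.random.rand(n, int(np.prod(batch_shape)))
# Restore the batch dims and append a singleton so `weight` broadcasts against
# tensors shaped [n, b1, ..., bB, d].
weight = weight.reshape([-1] + batch_shape + [1])   # (10, 2, 3, 1)
x0 = np.random.rand(n, *batch_shape, d)             # (10, 2, 3, 4)
print((weight * x0).shape)                          # (10, 2, 3, 4)
```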
@@ -686,7 +663,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
# To compute E[Cov(Z|V)], we'll add matrices within three categories:
# scaled-identity, diagonal, and full. Then we'll combine these at the end.
scaled_identity = None
scale_identity_multiplier = None
diag = None
full = None
@@ -694,10 +671,12 @@ class VectorDiffeomixture(distribution_lib.Distribution):
s = aff.scale # Just in case aff.scale has side-effects, we'll call once.
if (s is None
or isinstance(s, linop_identity_lib.LinearOperatorIdentity)):
scaled_identity = add(scaled_identity, p[..., k, array_ops.newaxis])
scale_identity_multiplier = add(scale_identity_multiplier,
p[..., k, array_ops.newaxis])
elif isinstance(s, linop_identity_lib.LinearOperatorScaledIdentity):
scaled_identity = add(scaled_identity, (p[..., k, array_ops.newaxis] *
math_ops.square(s.multiplier)))
scale_identity_multiplier = add(
scale_identity_multiplier,
(p[..., k, array_ops.newaxis] * math_ops.square(s.multiplier)))
elif isinstance(s, linop_diag_lib.LinearOperatorDiag):
diag = add(diag, (p[..., k, array_ops.newaxis] *
math_ops.square(s.diag_part())))
@@ -709,12 +688,13 @@ class VectorDiffeomixture(distribution_lib.Distribution):
full = add(full, x)
# We must now account for the fact that the base distribution might have a
# non-unity variance. Recall that `Cov(SX+m) = S.T Cov(X) S = S.T S Var(X)`.
# non-unity variance. Recall that, since X ~ iid Law(X_0),
# `Cov(SX+m) = S Cov(X) S.T = S S.T Diag(Var(X_0))`.
# We can scale by `Var(X)` (vs `Cov(X)`) since X corresponds to `d` iid
# samples from a scalar-event distribution.
v = self.distribution.variance()
if scaled_identity is not None:
scaled_identity *= v
if scale_identity_multiplier is not None:
scale_identity_multiplier *= v
if diag is not None:
diag *= v[..., array_ops.newaxis]
if full is not None:
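A small editorial NumPy check of the identity invoked in the comment above: with iid components, `Cov(X) = Var(X_0) * I`, hence `Cov(S X + m) = Var(X_0) * S S^T`:

```python
import numpy as np

rng = np.random.RandomState(0)
d = 3
S = rng.randn(d, d)                         # arbitrary scale matrix
m = rng.randn(d)                            # arbitrary shift
x = 1.5 * rng.randn(200000, d)              # iid components, Var(X_0) = 1.5**2
y = x @ S.T + m                             # y = S x + m, applied row-wise

empirical = np.cov(y, rowvar=False)
analytic = (1.5 ** 2) * S @ S.T             # Var(X_0) * S S^T
print(np.abs(empirical - analytic).max())   # small, up to Monte-Carlo error
```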
@@ -723,10 +703,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
if diag_only:
# Apparently we don't need the full matrix, just the diagonal.
r = add(diag, full)
if r is None and scaled_identity is not None:
if r is None and scale_identity_multiplier is not None:
ones = array_ops.ones(self.event_shape_tensor(), dtype=self.dtype)
return scaled_identity * ones
return add(r, scaled_identity)
return scale_identity_multiplier[..., array_ops.newaxis] * ones
return add(r, scale_identity_multiplier)
# `None` indicates we don't know if the result is positive-definite.
is_positive_definite = (True if all(aff.scale.is_positive_definite
@@ -742,10 +722,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
to_add.append(linop_full_lib.LinearOperatorFullMatrix(
matrix=full,
is_positive_definite=is_positive_definite))
if scaled_identity is not None:
if scale_identity_multiplier is not None:
to_add.append(linop_identity_lib.LinearOperatorScaledIdentity(
num_rows=self.event_shape_tensor()[0],
multiplier=scaled_identity,
multiplier=scale_identity_multiplier,
is_positive_definite=is_positive_definite))
return (linop_add_lib.add_operators(to_add)[0].to_dense()