Make tf.contrib.distributions quadrature family accept a Tensor for the `quadrature_grid_and_probs` argument.

PiperOrigin-RevId: 172950094
Joshua V. Dillon 2017-10-20 16:28:55 -07:00 committed by TensorFlower Gardener
parent 8ff33271ea
commit 4948379369
10 changed files with 216 additions and 97 deletions
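
For illustration, a minimal usage sketch of the new behavior (a hedged sketch, assuming `PoissonLogNormalQuadratureCompound` is exported from `tf.contrib.distributions`; the `deg=10` grid mirrors the tests added below):

```python
import numpy as np
import tensorflow as tf

tfd = tf.contrib.distributions

# The quadrature grid and probs may now be passed as Tensors rather than
# Python lists or ndarrays.
grid, probs = np.polynomial.hermite.hermgauss(deg=10)
pln = tfd.PoissonLogNormalQuadratureCompound(
    loc=-2.,
    scale=1.1,
    quadrature_grid_and_probs=(tf.constant(grid, tf.float32),
                               tf.constant(probs, tf.float32)),
    validate_args=True)

with tf.Session() as sess:
  samples_, log_probs_ = sess.run(
      [pln.sample(5, seed=42), pln.log_prob([0., 1., 2.])])
```

The new tests below also exercise fully dynamic grids by constructing the distribution from `tf.placeholder`s and feeding NumPy values at run time.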


@ -111,7 +111,7 @@ class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers,
event_shape=[dims],
validate_args=True)
self.run_test_sample_consistent_log_prob(
sess=sess,
sess_run_fn=sess.run,
dist=dist,
num_samples=int(1e5),
radius=1.,
@ -130,7 +130,7 @@ class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers,
event_shape=[dims],
validate_args=True)
self.run_test_sample_consistent_log_prob(
sess=sess,
sess_run_fn=sess.run,
dist=dist,
num_samples=int(1e5),
radius=1.,


@ -23,7 +23,6 @@ import numpy as np
from tensorflow.contrib.distributions.python.ops import independent as independent_lib
from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib
from tensorflow.contrib.distributions.python.ops import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.platform import test
@ -41,8 +40,7 @@ def try_import(name): # pylint: disable=invalid-name
stats = try_import("scipy.stats")
class ProductDistributionTest(
test_util.VectorDistributionTestHelpers, test.TestCase):
class ProductDistributionTest(test.TestCase):
def testSampleAndLogProbUnivariate(self):
loc = np.float32([-1., 1])


@ -94,10 +94,10 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]))
# Ball centered at component0's mean.
self.run_test_sample_consistent_log_prob(
sess, gm, radius=1., center=[-1., 1], rtol=0.02)
sess.run, gm, radius=1., center=[-1., 1], rtol=0.02)
# Larger ball centered at component1's mean.
self.run_test_sample_consistent_log_prob(
sess, gm, radius=1., center=[1., -1], rtol=0.02)
sess.run, gm, radius=1., center=[1., -1], rtol=0.02)
def testLogCdf(self):
with self.test_session() as sess:
@ -122,7 +122,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
components_distribution=mvn_diag_lib.MultivariateNormalDiag(
loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]))
self.run_test_sample_consistent_mean_covariance(sess, gm)
self.run_test_sample_consistent_mean_covariance(sess.run, gm)
def testVarianceConsistentCovariance(self):
with self.test_session() as sess:


@ -22,6 +22,8 @@ import numpy as np
from tensorflow.contrib.distributions.python.ops import poisson_lognormal
from tensorflow.contrib.distributions.python.ops import test_util
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@ -38,7 +40,7 @@ class PoissonLogNormalQuadratureCompoundTest(
np.polynomial.hermite.hermgauss(deg=10)),
validate_args=True)
self.run_test_sample_consistent_log_prob(
sess, pln, rtol=0.1)
sess.run, pln, rtol=0.1)
def testMeanVariance(self):
with self.test_session() as sess:
@ -49,7 +51,7 @@ class PoissonLogNormalQuadratureCompoundTest(
np.polynomial.hermite.hermgauss(deg=10)),
validate_args=True)
self.run_test_sample_consistent_mean_variance(
sess, pln, rtol=0.02)
sess.run, pln, rtol=0.02)
def testSampleProbConsistentBroadcastScalar(self):
with self.test_session() as sess:
@ -60,7 +62,7 @@ class PoissonLogNormalQuadratureCompoundTest(
np.polynomial.hermite.hermgauss(deg=10)),
validate_args=True)
self.run_test_sample_consistent_log_prob(
sess, pln, rtol=0.1, atol=0.01)
sess.run, pln, rtol=0.1, atol=0.01)
def testMeanVarianceBroadcastScalar(self):
with self.test_session() as sess:
@ -71,7 +73,7 @@ class PoissonLogNormalQuadratureCompoundTest(
np.polynomial.hermite.hermgauss(deg=10)),
validate_args=True)
self.run_test_sample_consistent_mean_variance(
sess, pln, rtol=0.1, atol=0.01)
sess.run, pln, rtol=0.1, atol=0.01)
def testSampleProbConsistentBroadcastBoth(self):
with self.test_session() as sess:
@ -82,7 +84,7 @@ class PoissonLogNormalQuadratureCompoundTest(
np.polynomial.hermite.hermgauss(deg=10)),
validate_args=True)
self.run_test_sample_consistent_log_prob(
sess, pln, rtol=0.1, atol=0.08)
sess.run, pln, rtol=0.1, atol=0.08)
def testMeanVarianceBroadcastBoth(self):
with self.test_session() as sess:
@ -93,7 +95,21 @@ class PoissonLogNormalQuadratureCompoundTest(
np.polynomial.hermite.hermgauss(deg=10)),
validate_args=True)
self.run_test_sample_consistent_mean_variance(
sess, pln, rtol=0.1, atol=0.01)
sess.run, pln, rtol=0.1, atol=0.01)
def testSampleProbConsistentDynamicQuadrature(self):
with self.test_session() as sess:
qgrid = array_ops.placeholder(dtype=dtypes.float32)
qprobs = array_ops.placeholder(dtype=dtypes.float32)
g, p = np.polynomial.hermite.hermgauss(deg=10)
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=-2.,
scale=1.1,
quadrature_grid_and_probs=(qgrid, qprobs),
validate_args=True)
self.run_test_sample_consistent_log_prob(
lambda x: sess.run(x, feed_dict={qgrid: g, qprobs: p}),
pln, rtol=0.1)
if __name__ == "__main__":


@ -22,6 +22,8 @@ import numpy as np
from tensorflow.contrib.distributions.python.ops import test_util
from tensorflow.contrib.distributions.python.ops import vector_diffeomixture as vector_diffeomixture_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib
from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib
@ -55,10 +57,10 @@ class VectorDiffeomixtureTest(
validate_args=True)
# Ball centered at component0's mean.
self.run_test_sample_consistent_log_prob(
sess, vdm, radius=2., center=0., rtol=0.005)
sess.run, vdm, radius=2., center=0., rtol=0.005)
# Larger ball centered at component1's mean.
self.run_test_sample_consistent_log_prob(
sess, vdm, radius=4., center=2., rtol=0.005)
sess.run, vdm, radius=4., center=2., rtol=0.005)
def testSampleProbConsistentBroadcastMixNonStandardBase(self):
with self.test_session() as sess:
@ -83,10 +85,10 @@ class VectorDiffeomixtureTest(
validate_args=True)
# Ball centered at component0's mean.
self.run_test_sample_consistent_log_prob(
sess, vdm, radius=2., center=1., rtol=0.006)
sess.run, vdm, radius=2., center=1., rtol=0.006)
# Larger ball centered at component1's mean.
self.run_test_sample_consistent_log_prob(
sess, vdm, radius=4., center=3., rtol=0.009)
sess.run, vdm, radius=4., center=3., rtol=0.009)
def testSampleProbConsistentBroadcastMixBatch(self):
with self.test_session() as sess:
@ -114,10 +116,10 @@ class VectorDiffeomixtureTest(
validate_args=True)
# Ball centered at component0's mean.
self.run_test_sample_consistent_log_prob(
sess, vdm, radius=2., center=0., rtol=0.005)
sess.run, vdm, radius=2., center=0., rtol=0.005)
# Larger ball centered at component1's mean.
self.run_test_sample_consistent_log_prob(
sess, vdm, radius=4., center=2., rtol=0.005)
sess.run, vdm, radius=4., center=2., rtol=0.005)
def testMeanCovarianceNoBatch(self):
with self.test_session() as sess:
@ -141,7 +143,7 @@ class VectorDiffeomixtureTest(
],
validate_args=True)
self.run_test_sample_consistent_mean_covariance(
sess, vdm, rtol=0.02, cov_rtol=0.06)
sess.run, vdm, rtol=0.02, cov_rtol=0.06)
def testMeanCovarianceNoBatchUncenteredNonStandardBase(self):
with self.test_session() as sess:
@ -165,7 +167,7 @@ class VectorDiffeomixtureTest(
],
validate_args=True)
self.run_test_sample_consistent_mean_covariance(
sess, vdm, num_samples=int(1e6), rtol=0.01, cov_atol=0.025)
sess.run, vdm, num_samples=int(1e6), rtol=0.01, cov_atol=0.025)
def testMeanCovarianceBatch(self):
with self.test_session() as sess:
@ -192,7 +194,40 @@ class VectorDiffeomixtureTest(
],
validate_args=True)
self.run_test_sample_consistent_mean_covariance(
sess, vdm, rtol=0.02, cov_rtol=0.06)
sess.run, vdm, rtol=0.02, cov_rtol=0.06)
def testSampleProbConsistentDynamicQuadrature(self):
with self.test_session() as sess:
qgrid = array_ops.placeholder(dtype=dtypes.float32)
qprobs = array_ops.placeholder(dtype=dtypes.float32)
g, p = np.polynomial.hermite.hermgauss(deg=8)
dims = 4
vdm = vector_diffeomixture_lib.VectorDiffeomixture(
mix_loc=[[0.], [1.]],
mix_scale=[1.],
distribution=normal_lib.Normal(0., 1.),
loc=[
None,
np.float32([2.]*dims),
],
scale=[
linop_identity_lib.LinearOperatorScaledIdentity(
num_rows=dims,
multiplier=np.float32(1.1),
is_positive_definite=True),
linop_diag_lib.LinearOperatorDiag(
diag=np.linspace(2.5, 3.5, dims, dtype=np.float32),
is_positive_definite=True),
],
quadrature_grid_and_probs=(qgrid, qprobs),
validate_args=True)
# Ball centered at component0's mean.
sess_run_fn = lambda x: sess.run(x, feed_dict={qgrid: g, qprobs: p})
self.run_test_sample_consistent_log_prob(
sess_run_fn, vdm, radius=2., center=0., rtol=0.005)
# Larger ball centered at component1's mean.
self.run_test_sample_consistent_log_prob(
sess_run_fn, vdm, radius=4., center=2., rtol=0.005)
# TODO(jvdillon): We've tested that (i) .sample and .log_prob are consistent,
# (ii) .mean, .stddev etc... and .sample are consistent. However, we haven't


@ -210,15 +210,15 @@ class VectorSinhArcsinhDiagTest(test_util.VectorDistributionTestHelpers,
validate_args=True)
self.run_test_sample_consistent_log_prob(
sess, sasnorm, radius=1.0, center=0., rtol=0.1)
sess.run, sasnorm, radius=1.0, center=0., rtol=0.1)
self.run_test_sample_consistent_log_prob(
sess,
sess.run,
sasnorm,
radius=1.0,
center=-0.15,
rtol=0.1)
self.run_test_sample_consistent_log_prob(
sess,
sess.run,
sasnorm,
radius=1.0,
center=0.15,
@ -237,15 +237,15 @@ class VectorSinhArcsinhDiagTest(test_util.VectorDistributionTestHelpers,
validate_args=True)
self.run_test_sample_consistent_log_prob(
sess, sasnorm, radius=1.0, center=0., rtol=0.1)
sess.run, sasnorm, radius=1.0, center=0., rtol=0.1)
self.run_test_sample_consistent_log_prob(
sess,
sess.run,
sasnorm,
radius=1.0,
center=-0.15,
rtol=0.1)
self.run_test_sample_consistent_log_prob(
sess,
sess.run,
sasnorm,
radius=1.0,
center=0.15,


@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@ -29,7 +30,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import categorical as categorical_lib
from tensorflow.python.ops.distributions import distribution as distribution_lib
from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
@ -55,8 +55,10 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
```
where `lambda(z) = exp(sqrt(2) scale z + loc)` and the `prob,grid` terms
are from [Gauss--Hermite quadrature](
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature). Note that
are from [numerical quadrature](
https://en.wikipedia.org/wiki/Numerical_integration) (default:
[Gauss--Hermite quadrature](
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). Note that
the second line made the substitution:
`z(l) = (log(l) - loc) / (sqrt(2) scale)` which implies `lambda(z)` [above]
and `dl = sqrt(2) scale lambda(z) dz`
@ -65,8 +67,11 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
Poisson rate parameter. Unfortunately, the non-approximate distribution lacks
an analytical probability density function (pdf). Therefore the
`PoissonLogNormalQuadratureCompound` class implements an approximation based
on [Gauss-Hermite quadrature](
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature).
on [numerical quadrature](
https://en.wikipedia.org/wiki/Numerical_integration) (default:
[Gauss--Hermite quadrature](
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)).
Note: although the `PoissonLogNormalQuadratureCompound` is approximately the
Poisson-LogNormal compound distribution, it is itself a valid distribution.
Viz., it possesses a `sample`, `log_prob`, `mean`, `variance`, etc. which are
@ -76,9 +81,11 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
The `PoissonLogNormalQuadratureCompound` approximates a Poisson-LogNormal
[compound distribution](
https://en.wikipedia.org/wiki/Compound_probability_distribution).
Using variable-substitution and [Gauss-Hermite quadrature](
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature) we can
https://en.wikipedia.org/wiki/Compound_probability_distribution). Using
variable-substitution and [numerical quadrature](
https://en.wikipedia.org/wiki/Numerical_integration) (default:
[Gauss--Hermite quadrature](
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)) we can
redefine the distribution to be a parameter-less convex combination of `deg`
different Poisson samples.
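For concreteness, a small NumPy/SciPy-only sketch of the approximation this docstring describes, with SciPy's Poisson pmf standing in for the TF Poisson distribution (illustrative `loc`, `scale`, and `deg` values):

```python
import numpy as np
from scipy import stats

loc, scale = -2., 1.1
grid, weight = np.polynomial.hermite.hermgauss(deg=8)
probs = weight / weight.sum()                     # == weight / sqrt(pi)
rates = np.exp(np.sqrt(2.) * scale * grid + loc)  # lambda(grid[i]) from above
k = np.arange(5)
# pmf(k) ~= sum_i probs[i] * Poisson(k; rates[i])
approx_pmf = (probs[:, None] * stats.poisson.pmf(k, rates[:, None])).sum(axis=0)
```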
@ -125,9 +132,10 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
the LogNormal prior.
scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
the LogNormal prior.
quadrature_grid_and_probs: Python pair of `list`-like objects representing
the sample points and the corresponding (possibly normalized) weight.
When `None`, defaults to: `np.polynomial.hermite.hermgauss(deg=8)`.
quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
representing the sample points and the corresponding (possibly
normalized) weight. When `None`, defaults to:
`np.polynomial.hermite.hermgauss(deg=8)`.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
performance. When `False` invalid inputs may silently render incorrect
@ -140,8 +148,6 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
Raises:
TypeError: if `loc.dtype != scale[0].dtype`.
ValueError: if `quadrature_grid_and_probs is not None` and
`len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
"""
parameters = locals()
with ops.name_scope(name, values=[loc, scale]):
@ -157,21 +163,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
"loc.dtype(\"{}\") does not match scale.dtype(\"{}\")".format(
loc.dtype.name, scale.dtype.name))
if quadrature_grid_and_probs is None:
grid, probs = np.polynomial.hermite.hermgauss(deg=8)
else:
grid, probs = tuple(quadrature_grid_and_probs)
if len(grid) != len(probs):
raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
"same-length list-like objects")
grid = grid.astype(dtype.as_numpy_dtype)
probs = probs.astype(dtype.as_numpy_dtype)
probs /= np.linalg.norm(probs, ord=1)
grid, probs = distribution_util.process_quadrature_grid_and_probs(
quadrature_grid_and_probs, dtype, validate_args)
self._quadrature_grid = grid
self._quadrature_probs = probs
self._quadrature_size = distribution_util.dimension_size(probs, axis=0)
self._mixture_distribution = categorical_lib.Categorical(
logits=np.log(probs),
logits=math_ops.log(self._quadrature_probs),
validate_args=validate_args,
allow_nan_stats=allow_nan_stats)
@ -254,10 +253,10 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
[batch_size])),
seed=distribution_util.gen_new_seed(
seed, "poisson_lognormal_quadrature_compound"))
# Stride `quadrature_degree` for `batch_size` number of times.
# Stride `quadrature_size` for `batch_size` number of times.
offset = math_ops.range(start=0,
limit=batch_size * len(self.quadrature_probs),
delta=len(self.quadrature_probs),
limit=batch_size * self._quadrature_size,
delta=self._quadrature_size,
dtype=ids.dtype)
ids += offset
rate = array_ops.gather(


@ -38,7 +38,7 @@ class DiscreteScalarDistributionTestHelpers(object):
"""DiscreteScalarDistributionTestHelpers."""
def run_test_sample_consistent_log_prob(
self, sess, dist,
self, sess_run_fn, dist,
num_samples=int(1e5), num_threshold=int(1e3), seed=42,
rtol=1e-2, atol=0.):
"""Tests that sample/log_prob are consistent with each other.
@ -51,7 +51,9 @@ class DiscreteScalarDistributionTestHelpers(object):
are consistent.
Args:
sess: Tensorflow session.
sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and
returning a list of results after running one "step" of TensorFlow
computation, typically set to `sess.run`.
dist: Distribution instance or object which implements `sample`,
`log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
num_samples: Python `int` scalar indicating the number of Monte-Carlo
@ -87,7 +89,7 @@ class DiscreteScalarDistributionTestHelpers(object):
probs = math_ops.exp(dist.log_prob(edges))
probs = array_ops.reshape(probs, shape=[-1, batch_size])[:, b]
[counts_, probs_] = sess.run([counts, probs])
[counts_, probs_] = sess_run_fn([counts, probs])
valid = counts_ > num_threshold
probs_ = probs_[valid]
counts_ = counts_[valid]
@ -95,7 +97,7 @@ class DiscreteScalarDistributionTestHelpers(object):
rtol=rtol, atol=atol)
def run_test_sample_consistent_mean_variance(
self, sess, dist,
self, sess_run_fn, dist,
num_samples=int(1e5), seed=24,
rtol=1e-2, atol=0.):
"""Tests that sample/mean/variance are consistent with each other.
@ -104,7 +106,9 @@ class DiscreteScalarDistributionTestHelpers(object):
to the same distribution.
Args:
sess: Tensorflow session.
sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and
returning a list of results after running one "step" of TensorFlow
computation, typically set to `sess.run`.
dist: Distribution instance or object which implements `sample`,
`log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
num_samples: Python `int` scalar indicating the number of Monte-Carlo
@ -130,7 +134,7 @@ class DiscreteScalarDistributionTestHelpers(object):
mean_,
variance_,
stddev_
] = sess.run([
] = sess_run_fn([
sample_mean,
sample_variance,
sample_stddev,
@ -187,7 +191,7 @@ class VectorDistributionTestHelpers(object):
def run_test_sample_consistent_log_prob(
self,
sess,
sess_run_fn,
dist,
num_samples=int(1e5),
radius=1.,
@ -240,7 +244,9 @@ class VectorDistributionTestHelpers(object):
https://en.wikipedia.org/wiki/Importance_sampling.
Args:
sess: Tensorflow session.
sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and
returning a list of results after running one "step" of TensorFlow
computation, typically set to `sess.run`.
dist: Distribution instance or object which implements `sample`,
`log_prob`, `event_shape_tensor` and `batch_shape_tensor`. The
distribution must have non-zero probability of sampling every point
@ -301,8 +307,8 @@ class VectorDistributionTestHelpers(object):
init_op = variables_ops.global_variables_initializer()
# Execute graph.
sess.run(init_op)
[batch_shape_, actual_volume_, sample_volume_] = sess.run([
sess_run_fn(init_op)
[batch_shape_, actual_volume_, sample_volume_] = sess_run_fn([
batch_shape, actual_volume, sample_volume])
# Check results.
@ -312,7 +318,7 @@ class VectorDistributionTestHelpers(object):
def run_test_sample_consistent_mean_covariance(
self,
sess,
sess_run_fn,
dist,
num_samples=int(1e5),
seed=24,
@ -326,7 +332,9 @@ class VectorDistributionTestHelpers(object):
to the same distribution.
Args:
sess: Tensorflow session.
sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and
returning a list of results after running one "step" of TensorFlow
computation, typically set to `sess.run`.
dist: Distribution instance or object which implements `sample`,
`log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
num_samples: Python `int` scalar indicating the number of Monte-Carlo
@ -360,7 +368,7 @@ class VectorDistributionTestHelpers(object):
covariance_,
variance_,
stddev_
] = sess.run([
] = sess_run_fn([
sample_mean,
sample_covariance,
sample_variance,


@ -73,8 +73,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
denotes matrix multiplication. However, the non-approximate distribution does
not have an analytical probability density function (pdf). Therefore the
`VectorDiffeomixture` class implements an approximation based on
[Gauss-Hermite quadrature](
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature). I.e., in
[numerical quadrature](
https://en.wikipedia.org/wiki/Numerical_integration) (default:
[Gauss--Hermite quadrature](
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). I.e., in
Note: although the `VectorDiffeomixture` is approximately the
`SoftmaxNormal-Distribution` compound distribution, it is itself a valid
distribution. It possesses a `sample`, `log_prob`, `mean`, `covariance` which
@ -109,8 +111,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
The `VectorDiffeomixture` approximates a SoftmaxNormal-mixed ("prior")
[compound distribution](
https://en.wikipedia.org/wiki/Compound_probability_distribution).
Using variable-substitution and [Gauss-Hermite quadrature](
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature) we can
Using variable-substitution and [numerical quadrature](
https://en.wikipedia.org/wiki/Numerical_integration) (default:
[Gauss--Hermite quadrature](
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)) we can
redefine the distribution to be a parameter-less convex combination of `K`
different affine combinations of a `d` iid samples from `distribution`.
@ -141,7 +145,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
and,
```none
grid, weight = np.polynomial.hermite.hermgauss(quadrature_degree)
grid, weight = np.polynomial.hermite.hermgauss(quadrature_size)
prob[k] = weight[k] / sqrt(pi)
lambda[k; i] = sigmoid(mix_loc[k] + sqrt(2) mix_scale[k] grid[i])
```
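Instantiating that formula for a scalar `mix_loc`/`mix_scale` (a hedged NumPy sketch; `expit` is SciPy's sigmoid):

```python
import numpy as np
from scipy.special import expit  # sigmoid

grid, weight = np.polynomial.hermite.hermgauss(deg=8)
prob = weight / np.sqrt(np.pi)               # prob[k] in the formula above
lam = expit(0. + np.sqrt(2.) * 1. * grid)    # lambda[k] for mix_loc=0, mix_scale=1
```

Roughly, each `lam[k]` acts as the convex-combination weight used to blend the two affine (`loc`, `scale`) components at quadrature point `k`.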
@ -248,9 +252,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
`k`-th element represents the `scale` used for the `k`-th affine
transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`,
`b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices
quadrature_grid_and_probs: Python pair of `list`-like objects representing
the sample points and the corresponding (possibly normalized) weight.
When `None`, defaults to: `np.polynomial.hermite.hermgauss(deg=8)`.
quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
representing the sample points and the corresponding (possibly
normalized) weight. When `None`, defaults to:
`np.polynomial.hermite.hermgauss(deg=8)`.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
performance. When `False` invalid inputs may silently render incorrect
@ -317,24 +322,17 @@ class VectorDiffeomixture(distribution_lib.Distribution):
raise NotImplementedError("Currently only bimixtures are supported; "
"len(scale)={} is not 2.".format(len(scale)))
if quadrature_grid_and_probs is None:
grid, probs = np.polynomial.hermite.hermgauss(deg=8)
else:
grid, probs = tuple(quadrature_grid_and_probs)
if len(grid) != len(probs):
raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
"same-length list-like objects")
grid = grid.astype(dtype.as_numpy_dtype)
probs = probs.astype(dtype.as_numpy_dtype)
probs /= np.linalg.norm(probs, ord=1)
grid, probs = distribution_util.process_quadrature_grid_and_probs(
quadrature_grid_and_probs, dtype, validate_args)
self._quadrature_grid = grid
self._quadrature_probs = probs
self._quadrature_size = distribution_util.dimension_size(probs, axis=0)
# Note: by creating the logits as `log(prob)` we ensure that
# `self.mixture_distribution.logits` is equivalent to
# `math_ops.log(self.mixture_distribution.probs)`.
self._mixture_distribution = categorical_lib.Categorical(
logits=np.log(probs),
logits=math_ops.log(probs),
validate_args=validate_args,
allow_nan_stats=allow_nan_stats)
@ -361,10 +359,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
validate_args=validate_args,
name="interpolated_affine_{}".format(k))
for k, (loc_, scale_) in enumerate(zip(
interpolate_loc(len(self._quadrature_grid),
interpolate_loc(self._quadrature_size,
self._interpolate_weight,
loc),
interpolate_scale(len(self._quadrature_grid),
interpolate_scale(self._quadrature_size,
self._interpolate_weight,
scale)))]
@ -463,10 +461,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
seed=distribution_util.gen_new_seed(
seed, "vector_diffeomixture"))
# Stride `quadrature_degree` for `batch_size` number of times.
# Stride `quadrature_size` for `batch_size` number of times.
offset = math_ops.range(start=0,
limit=batch_size * len(self.quadrature_probs),
delta=len(self.quadrature_probs),
limit=batch_size * self._quadrature_size,
delta=self._quadrature_size,
dtype=ids.dtype)
weight = array_ops.gather(


@ -29,6 +29,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
@ -1049,13 +1050,77 @@ def dimension_size(x, axis):
"""Returns the size of a specific dimension."""
# Since tf.gather isn't "constant-in, constant-out", we must first check the
# static shape or fallback to dynamic shape.
num_rows = (None if x.get_shape().ndims is None
else x.get_shape()[axis].value)
if num_rows is not None:
return num_rows
s = x.shape.with_rank_at_least(axis + 1)[axis].value
if axis > -1 and s is not None:
return s
return array_ops.shape(x)[axis]
def process_quadrature_grid_and_probs(
quadrature_grid_and_probs, dtype, validate_args, name=None):
"""Validates quadrature grid, probs or computes them as necessary.
Args:
quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
representing the sample points and the corresponding (possibly
normalized) weight. When `None`, defaults to:
`np.polynomial.hermite.hermgauss(deg=8)`.
dtype: The expected `dtype` of `grid` and `probs`.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
performance. When `False` invalid inputs may silently render incorrect
outputs.
name: Python `str` name prefixed to Ops created by this class.
Returns:
quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
representing the sample points and the corresponding (possibly
normalized) weight.
Raises:
ValueError: if `quadrature_grid_and_probs is not None` and
`len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
"""
with ops.name_scope(name, "process_quadrature_grid_and_probs",
[quadrature_grid_and_probs]):
if quadrature_grid_and_probs is None:
grid, probs = np.polynomial.hermite.hermgauss(deg=8)
grid = grid.astype(dtype.as_numpy_dtype)
probs = probs.astype(dtype.as_numpy_dtype)
probs /= np.linalg.norm(probs, ord=1, keepdims=True)
grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
return grid, probs
grid, probs = tuple(quadrature_grid_and_probs)
grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
probs = ops.convert_to_tensor(probs, name="unnormalized_probs",
dtype=dtype)
probs /= linalg_ops.norm(probs, ord=1, axis=-1, keep_dims=True,
name="probs")
def _static_dim_size(x, axis):
"""Returns the static size of a specific dimension or `None`."""
return x.shape.with_rank_at_least(axis + 1)[axis].value
m, n = _static_dim_size(probs, axis=0), _static_dim_size(grid, axis=0)
if m is not None and n is not None:
if m != n:
raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
"same-length zero-th-dimension `Tensor`s "
"(saw lengths {}, {})".format(m, n))
elif validate_args:
grid = control_flow_ops.with_dependencies([
check_ops.assert_equal(
dimension_size(probs, axis=0),
dimension_size(grid, axis=0),
message=("`quadrature_grid_and_probs` must be a `tuple` of "
"same-length zero-th-dimension `Tensor`s")),
], grid)
return grid, probs
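A brief usage sketch of this helper (hedged; it mirrors how the `PoissonLogNormalQuadratureCompound` and `VectorDiffeomixture` constructors above call it, with `deg=10` purely illustrative):

```python
import numpy as np
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops

# Default path: grid/probs come from np.polynomial.hermite.hermgauss(deg=8).
grid, probs = distribution_util.process_quadrature_grid_and_probs(
    None, dtypes.float32, validate_args=False)

# Tensor path: user-supplied (possibly unnormalized) grid and weights; the
# weights are normalized to sum to one inside the helper.
g, w = np.polynomial.hermite.hermgauss(deg=10)
grid, probs = distribution_util.process_quadrature_grid_and_probs(
    (ops.convert_to_tensor(g, dtype=dtypes.float32),
     ops.convert_to_tensor(w, dtype=dtypes.float32)),
    dtypes.float32, validate_args=True)
```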
class AppendDocstring(object):
"""Helper class to promote private subclass docstring to public counterpart.