Make tf.contrib.distributions quadrature family accept a Tensor for the
`quadrature_grid_and_probs` argument.

PiperOrigin-RevId: 172950094
Commit 4948379369 (parent 8ff33271ea)
Changed files:

  tensorflow/contrib/distributions/python/
    kernel_tests/bijectors/  (MaskedAutoregressiveFlowTest)
    kernel_tests/independent_test.py
    kernel_tests/mixture_same_family_test.py
    kernel_tests/poisson_lognormal_test.py
    kernel_tests/vector_diffeomixture_test.py
    kernel_tests/vector_sinh_arcsinh_diag_test.py
    ops/poisson_lognormal.py
    ops/test_util.py
    ops/vector_diffeomixture.py
  tensorflow/python/ops/distributions/util.py
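In short: `quadrature_grid_and_probs` may now be a pair of `float`-like
`Tensor`s (e.g. placeholders fed at run time) instead of list-like numpy
objects. A minimal sketch of what the change enables, in TF 1.x graph mode
(the module path is taken from the diff below; variable names are
illustrative):

```python
import numpy as np
import tensorflow as tf
from tensorflow.contrib.distributions.python.ops import poisson_lognormal

# Quadrature grid/probs become graph inputs rather than build-time constants.
qgrid = tf.placeholder(dtype=tf.float32)
qprobs = tf.placeholder(dtype=tf.float32)
g, p = np.polynomial.hermite.hermgauss(deg=8)

pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
    loc=-2.,
    scale=1.1,
    quadrature_grid_and_probs=(qgrid, qprobs),
    validate_args=True)

with tf.Session() as sess:
  samples = sess.run(pln.sample(4, seed=42),
                     feed_dict={qgrid: g, qprobs: p})
```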
tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py

@@ -111,7 +111,7 @@ class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers,
           event_shape=[dims],
           validate_args=True)
       self.run_test_sample_consistent_log_prob(
-          sess=sess,
+          sess_run_fn=sess.run,
           dist=dist,
           num_samples=int(1e5),
           radius=1.,
@@ -130,7 +130,7 @@ class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers,
           event_shape=[dims],
           validate_args=True)
       self.run_test_sample_consistent_log_prob(
-          sess=sess,
+          sess_run_fn=sess.run,
           dist=dist,
           num_samples=int(1e5),
           radius=1.,
tensorflow/contrib/distributions/python/kernel_tests/independent_test.py

@@ -23,7 +23,6 @@ import numpy as np
 
 from tensorflow.contrib.distributions.python.ops import independent as independent_lib
 from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib
-from tensorflow.contrib.distributions.python.ops import test_util
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.distributions import normal as normal_lib
 from tensorflow.python.platform import test
@@ -41,8 +40,7 @@ def try_import(name):  # pylint: disable=invalid-name
 stats = try_import("scipy.stats")
 
 
-class ProductDistributionTest(
-    test_util.VectorDistributionTestHelpers, test.TestCase):
+class ProductDistributionTest(test.TestCase):
 
   def testSampleAndLogProbUnivariate(self):
     loc = np.float32([-1., 1])
tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py

@@ -94,10 +94,10 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
               loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]))
       # Ball centered at component0's mean.
       self.run_test_sample_consistent_log_prob(
-          sess, gm, radius=1., center=[-1., 1], rtol=0.02)
+          sess.run, gm, radius=1., center=[-1., 1], rtol=0.02)
       # Larger ball centered at component1's mean.
       self.run_test_sample_consistent_log_prob(
-          sess, gm, radius=1., center=[1., -1], rtol=0.02)
+          sess.run, gm, radius=1., center=[1., -1], rtol=0.02)
 
   def testLogCdf(self):
     with self.test_session() as sess:
@@ -122,7 +122,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
           mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
           components_distribution=mvn_diag_lib.MultivariateNormalDiag(
               loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]))
-      self.run_test_sample_consistent_mean_covariance(sess, gm)
+      self.run_test_sample_consistent_mean_covariance(sess.run, gm)
 
   def testVarianceConsistentCovariance(self):
     with self.test_session() as sess:
tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py

@@ -22,6 +22,8 @@ import numpy as np
 
 from tensorflow.contrib.distributions.python.ops import poisson_lognormal
 from tensorflow.contrib.distributions.python.ops import test_util
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
 
 
@@ -38,7 +40,7 @@ class PoissonLogNormalQuadratureCompoundTest(
               np.polynomial.hermite.hermgauss(deg=10)),
           validate_args=True)
       self.run_test_sample_consistent_log_prob(
-          sess, pln, rtol=0.1)
+          sess.run, pln, rtol=0.1)
 
   def testMeanVariance(self):
     with self.test_session() as sess:
@@ -49,7 +51,7 @@ class PoissonLogNormalQuadratureCompoundTest(
               np.polynomial.hermite.hermgauss(deg=10)),
           validate_args=True)
       self.run_test_sample_consistent_mean_variance(
-          sess, pln, rtol=0.02)
+          sess.run, pln, rtol=0.02)
 
   def testSampleProbConsistentBroadcastScalar(self):
     with self.test_session() as sess:
@@ -60,7 +62,7 @@ class PoissonLogNormalQuadratureCompoundTest(
               np.polynomial.hermite.hermgauss(deg=10)),
           validate_args=True)
       self.run_test_sample_consistent_log_prob(
-          sess, pln, rtol=0.1, atol=0.01)
+          sess.run, pln, rtol=0.1, atol=0.01)
 
   def testMeanVarianceBroadcastScalar(self):
     with self.test_session() as sess:
@@ -71,7 +73,7 @@ class PoissonLogNormalQuadratureCompoundTest(
               np.polynomial.hermite.hermgauss(deg=10)),
           validate_args=True)
       self.run_test_sample_consistent_mean_variance(
-          sess, pln, rtol=0.1, atol=0.01)
+          sess.run, pln, rtol=0.1, atol=0.01)
 
   def testSampleProbConsistentBroadcastBoth(self):
     with self.test_session() as sess:
@@ -82,7 +84,7 @@ class PoissonLogNormalQuadratureCompoundTest(
               np.polynomial.hermite.hermgauss(deg=10)),
           validate_args=True)
       self.run_test_sample_consistent_log_prob(
-          sess, pln, rtol=0.1, atol=0.08)
+          sess.run, pln, rtol=0.1, atol=0.08)
 
   def testMeanVarianceBroadcastBoth(self):
     with self.test_session() as sess:
@@ -93,7 +95,21 @@ class PoissonLogNormalQuadratureCompoundTest(
               np.polynomial.hermite.hermgauss(deg=10)),
           validate_args=True)
       self.run_test_sample_consistent_mean_variance(
-          sess, pln, rtol=0.1, atol=0.01)
+          sess.run, pln, rtol=0.1, atol=0.01)
+
+  def testSampleProbConsistentDynamicQuadrature(self):
+    with self.test_session() as sess:
+      qgrid = array_ops.placeholder(dtype=dtypes.float32)
+      qprobs = array_ops.placeholder(dtype=dtypes.float32)
+      g, p = np.polynomial.hermite.hermgauss(deg=10)
+      pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
+          loc=-2.,
+          scale=1.1,
+          quadrature_grid_and_probs=(qgrid, qprobs),
+          validate_args=True)
+      self.run_test_sample_consistent_log_prob(
+          lambda x: sess.run(x, feed_dict={qgrid: g, qprobs: p}),
+          pln, rtol=0.1)
 
 
 if __name__ == "__main__":
tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py

@@ -22,6 +22,8 @@ import numpy as np
 
 from tensorflow.contrib.distributions.python.ops import test_util
 from tensorflow.contrib.distributions.python.ops import vector_diffeomixture as vector_diffeomixture_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops.distributions import normal as normal_lib
 from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib
 from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib
@@ -55,10 +57,10 @@ class VectorDiffeomixtureTest(
           validate_args=True)
       # Ball centered at component0's mean.
       self.run_test_sample_consistent_log_prob(
-          sess, vdm, radius=2., center=0., rtol=0.005)
+          sess.run, vdm, radius=2., center=0., rtol=0.005)
       # Larger ball centered at component1's mean.
       self.run_test_sample_consistent_log_prob(
-          sess, vdm, radius=4., center=2., rtol=0.005)
+          sess.run, vdm, radius=4., center=2., rtol=0.005)
 
   def testSampleProbConsistentBroadcastMixNonStandardBase(self):
     with self.test_session() as sess:
@@ -83,10 +85,10 @@ class VectorDiffeomixtureTest(
           validate_args=True)
       # Ball centered at component0's mean.
       self.run_test_sample_consistent_log_prob(
-          sess, vdm, radius=2., center=1., rtol=0.006)
+          sess.run, vdm, radius=2., center=1., rtol=0.006)
       # Larger ball centered at component1's mean.
       self.run_test_sample_consistent_log_prob(
-          sess, vdm, radius=4., center=3., rtol=0.009)
+          sess.run, vdm, radius=4., center=3., rtol=0.009)
 
   def testSampleProbConsistentBroadcastMixBatch(self):
     with self.test_session() as sess:
@@ -114,10 +116,10 @@ class VectorDiffeomixtureTest(
           validate_args=True)
       # Ball centered at component0's mean.
       self.run_test_sample_consistent_log_prob(
-          sess, vdm, radius=2., center=0., rtol=0.005)
+          sess.run, vdm, radius=2., center=0., rtol=0.005)
       # Larger ball centered at component1's mean.
       self.run_test_sample_consistent_log_prob(
-          sess, vdm, radius=4., center=2., rtol=0.005)
+          sess.run, vdm, radius=4., center=2., rtol=0.005)
 
   def testMeanCovarianceNoBatch(self):
     with self.test_session() as sess:
@@ -141,7 +143,7 @@ class VectorDiffeomixtureTest(
           ],
           validate_args=True)
       self.run_test_sample_consistent_mean_covariance(
-          sess, vdm, rtol=0.02, cov_rtol=0.06)
+          sess.run, vdm, rtol=0.02, cov_rtol=0.06)
 
   def testMeanCovarianceNoBatchUncenteredNonStandardBase(self):
     with self.test_session() as sess:
@@ -165,7 +167,7 @@ class VectorDiffeomixtureTest(
           ],
           validate_args=True)
       self.run_test_sample_consistent_mean_covariance(
-          sess, vdm, num_samples=int(1e6), rtol=0.01, cov_atol=0.025)
+          sess.run, vdm, num_samples=int(1e6), rtol=0.01, cov_atol=0.025)
 
   def testMeanCovarianceBatch(self):
     with self.test_session() as sess:
@@ -192,7 +194,40 @@ class VectorDiffeomixtureTest(
           ],
           validate_args=True)
       self.run_test_sample_consistent_mean_covariance(
-          sess, vdm, rtol=0.02, cov_rtol=0.06)
+          sess.run, vdm, rtol=0.02, cov_rtol=0.06)
+
+  def testSampleProbConsistentDynamicQuadrature(self):
+    with self.test_session() as sess:
+      qgrid = array_ops.placeholder(dtype=dtypes.float32)
+      qprobs = array_ops.placeholder(dtype=dtypes.float32)
+      g, p = np.polynomial.hermite.hermgauss(deg=8)
+      dims = 4
+      vdm = vector_diffeomixture_lib.VectorDiffeomixture(
+          mix_loc=[[0.], [1.]],
+          mix_scale=[1.],
+          distribution=normal_lib.Normal(0., 1.),
+          loc=[
+              None,
+              np.float32([2.]*dims),
+          ],
+          scale=[
+              linop_identity_lib.LinearOperatorScaledIdentity(
+                  num_rows=dims,
+                  multiplier=np.float32(1.1),
+                  is_positive_definite=True),
+              linop_diag_lib.LinearOperatorDiag(
+                  diag=np.linspace(2.5, 3.5, dims, dtype=np.float32),
+                  is_positive_definite=True),
+          ],
+          quadrature_grid_and_probs=(qgrid, qprobs),
+          validate_args=True)
+      # Ball centered at component0's mean.
+      sess_run_fn = lambda x: sess.run(x, feed_dict={qgrid: g, qprobs: p})
+      self.run_test_sample_consistent_log_prob(
+          sess_run_fn, vdm, radius=2., center=0., rtol=0.005)
+      # Larger ball centered at component1's mean.
+      self.run_test_sample_consistent_log_prob(
+          sess_run_fn, vdm, radius=4., center=2., rtol=0.005)
 
 # TODO(jvdillon): We've tested that (i) .sample and .log_prob are consistent,
 # (ii) .mean, .stddev etc... and .sample are consistent. However, we haven't
tensorflow/contrib/distributions/python/kernel_tests/vector_sinh_arcsinh_diag_test.py

@@ -210,15 +210,15 @@ class VectorSinhArcsinhDiagTest(test_util.VectorDistributionTestHelpers,
           validate_args=True)
 
       self.run_test_sample_consistent_log_prob(
-          sess, sasnorm, radius=1.0, center=0., rtol=0.1)
+          sess.run, sasnorm, radius=1.0, center=0., rtol=0.1)
       self.run_test_sample_consistent_log_prob(
-          sess,
+          sess.run,
           sasnorm,
           radius=1.0,
           center=-0.15,
           rtol=0.1)
       self.run_test_sample_consistent_log_prob(
-          sess,
+          sess.run,
           sasnorm,
           radius=1.0,
           center=0.15,
@@ -237,15 +237,15 @@ class VectorSinhArcsinhDiagTest(test_util.VectorDistributionTestHelpers,
           validate_args=True)
 
       self.run_test_sample_consistent_log_prob(
-          sess, sasnorm, radius=1.0, center=0., rtol=0.1)
+          sess.run, sasnorm, radius=1.0, center=0., rtol=0.1)
       self.run_test_sample_consistent_log_prob(
-          sess,
+          sess.run,
           sasnorm,
           radius=1.0,
           center=-0.15,
           rtol=0.1)
       self.run_test_sample_consistent_log_prob(
-          sess,
+          sess.run,
           sasnorm,
           radius=1.0,
           center=0.15,
tensorflow/contrib/distributions/python/ops/poisson_lognormal.py

@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.contrib.distributions.python.ops import distribution_util
 from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -29,7 +30,6 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops.distributions import categorical as categorical_lib
 from tensorflow.python.ops.distributions import distribution as distribution_lib
-from tensorflow.python.ops.distributions import util as distribution_util
 
 
 __all__ = [
@@ -55,8 +55,10 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
   ```
 
   where `lambda(z) = exp(sqrt(2) scale z + loc)` and the `prob,grid` terms
-  are from [Gauss--Hermite quadrature](
-  https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature). Note that
+  are from [numerical quadrature](
+  https://en.wikipedia.org/wiki/Numerical_integration) (default:
+  [Gauss--Hermite quadrature](
+  https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). Note that
   the second line made the substitution:
   `z(l) = (log(l) - loc) / (sqrt(2) scale)` which implies `lambda(z)` [above]
   and `dl = sqrt(2) scale lambda(z) dz`
@@ -65,8 +67,11 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
   Poisson rate parameter. Unfortunately, the non-approximate distribution lacks
   an analytical probability density function (pdf). Therefore the
   `PoissonLogNormalQuadratureCompound` class implements an approximation based
-  on [Gauss-Hermite quadrature](
-  https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature).
+  on [numerical quadrature](
+  https://en.wikipedia.org/wiki/Numerical_integration) (default:
+  [Gauss--Hermite quadrature](
+  https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)).
 
   Note: although the `PoissonLogNormalQuadratureCompound` is approximately the
   Poisson-LogNormal compound distribution, it is itself a valid distribution.
   Viz., it possesses a `sample`, `log_prob`, `mean`, `variance`, etc. which are
@@ -76,9 +81,11 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
 
   The `PoissonLogNormalQuadratureCompound` approximates a Poisson-LogNormal
   [compound distribution](
-  https://en.wikipedia.org/wiki/Compound_probability_distribution).
-  Using variable-substitution and [Gauss-Hermite quadrature](
-  https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature) we can
+  https://en.wikipedia.org/wiki/Compound_probability_distribution). Using
+  variable-substitution and [numerical quadrature](
+  https://en.wikipedia.org/wiki/Numerical_integration) (default:
+  [Gauss--Hermite quadrature](
+  https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)) we can
   redefine the distribution to be a parameter-less convex combination of `deg`
   different Poisson samples.
 
@@ -125,9 +132,10 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
         the LogNormal prior.
       scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
         the LogNormal prior.
-      quadrature_grid_and_probs: Python pair of `list`-like objects representing
-        the sample points and the corresponding (possibly normalized) weight.
-        When `None`, defaults to: `np.polynomial.hermite.hermgauss(deg=8)`.
+      quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
+        representing the sample points and the corresponding (possibly
+        normalized) weight. When `None`, defaults to:
+        `np.polynomial.hermite.hermgauss(deg=8)`.
       validate_args: Python `bool`, default `False`. When `True` distribution
         parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
@@ -140,8 +148,6 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
 
     Raises:
       TypeError: if `loc.dtype != scale[0].dtype`.
-      ValueError: if `quadrature_grid_and_probs is not None` and
-        `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
     """
     parameters = locals()
     with ops.name_scope(name, values=[loc, scale]):
@@ -157,21 +163,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
             "loc.dtype(\"{}\") does not match scale.dtype(\"{}\")".format(
                 loc.dtype.name, scale.dtype.name))
 
-      if quadrature_grid_and_probs is None:
-        grid, probs = np.polynomial.hermite.hermgauss(deg=8)
-      else:
-        grid, probs = tuple(quadrature_grid_and_probs)
-        if len(grid) != len(probs):
-          raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
-                           "same-length list-like objects")
-      grid = grid.astype(dtype.as_numpy_dtype)
-      probs = probs.astype(dtype.as_numpy_dtype)
-      probs /= np.linalg.norm(probs, ord=1)
+      grid, probs = distribution_util.process_quadrature_grid_and_probs(
+          quadrature_grid_and_probs, dtype, validate_args)
       self._quadrature_grid = grid
       self._quadrature_probs = probs
+      self._quadrature_size = distribution_util.dimension_size(probs, axis=0)
 
       self._mixture_distribution = categorical_lib.Categorical(
-          logits=np.log(probs),
+          logits=math_ops.log(self._quadrature_probs),
           validate_args=validate_args,
           allow_nan_stats=allow_nan_stats)
@@ -254,10 +253,10 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
                   [batch_size])),
           seed=distribution_util.gen_new_seed(
               seed, "poisson_lognormal_quadrature_compound"))
-      # Stride `quadrature_degree` for `batch_size` number of times.
+      # Stride `quadrature_size` for `batch_size` number of times.
       offset = math_ops.range(start=0,
-                              limit=batch_size * len(self.quadrature_probs),
-                              delta=len(self.quadrature_probs),
+                              limit=batch_size * self._quadrature_size,
+                              delta=self._quadrature_size,
                               dtype=ids.dtype)
       ids += offset
       rate = array_ops.gather(
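The docstrings above note that `probs` may be unnormalized and that the
default grid is `np.polynomial.hermite.hermgauss(deg=8)`. Gauss-Hermite
weights sum to `sqrt(pi)` (they integrate `exp(-x**2)` exactly), so the L1
normalization applied here reduces to the `prob[k] = weight[k] / sqrt(pi)`
formula quoted in the `VectorDiffeomixture` docstring. A small numpy sketch:

```python
import numpy as np

# Default quadrature used when `quadrature_grid_and_probs` is None.
grid, weight = np.polynomial.hermite.hermgauss(deg=8)

# L1 normalization; for positive weights this is just division by the sum.
probs = weight / np.linalg.norm(weight, ord=1)

# Equivalent to the docstring's weight[k] / sqrt(pi), and sums to one.
assert np.allclose(probs, weight / np.sqrt(np.pi))
assert np.isclose(probs.sum(), 1.)
```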
tensorflow/contrib/distributions/python/ops/test_util.py

@@ -38,7 +38,7 @@ class DiscreteScalarDistributionTestHelpers(object):
   """DiscreteScalarDistributionTestHelpers."""
 
   def run_test_sample_consistent_log_prob(
-      self, sess, dist,
+      self, sess_run_fn, dist,
       num_samples=int(1e5), num_threshold=int(1e3), seed=42,
       rtol=1e-2, atol=0.):
     """Tests that sample/log_prob are consistent with each other.
@@ -51,7 +51,9 @@ class DiscreteScalarDistributionTestHelpers(object):
     are consistent.
 
     Args:
-      sess: Tensorflow session.
+      sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and
+        returning a list of results after running one "step" of TensorFlow
+        computation, typically set to `sess.run`.
       dist: Distribution instance or object which implements `sample`,
         `log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
       num_samples: Python `int` scalar indicating the number of Monte-Carlo
@@ -87,7 +89,7 @@ class DiscreteScalarDistributionTestHelpers(object):
       probs = math_ops.exp(dist.log_prob(edges))
       probs = array_ops.reshape(probs, shape=[-1, batch_size])[:, b]
 
-      [counts_, probs_] = sess.run([counts, probs])
+      [counts_, probs_] = sess_run_fn([counts, probs])
       valid = counts_ > num_threshold
       probs_ = probs_[valid]
       counts_ = counts_[valid]
@@ -95,7 +97,7 @@ class DiscreteScalarDistributionTestHelpers(object):
         rtol=rtol, atol=atol)
 
   def run_test_sample_consistent_mean_variance(
-      self, sess, dist,
+      self, sess_run_fn, dist,
       num_samples=int(1e5), seed=24,
       rtol=1e-2, atol=0.):
     """Tests that sample/mean/variance are consistent with each other.
@@ -104,7 +106,9 @@ class DiscreteScalarDistributionTestHelpers(object):
     to the same distribution.
 
     Args:
-      sess: Tensorflow session.
+      sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and
+        returning a list of results after running one "step" of TensorFlow
+        computation, typically set to `sess.run`.
       dist: Distribution instance or object which implements `sample`,
         `log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
       num_samples: Python `int` scalar indicating the number of Monte-Carlo
@@ -130,7 +134,7 @@ class DiscreteScalarDistributionTestHelpers(object):
         mean_,
         variance_,
         stddev_
-    ] = sess.run([
+    ] = sess_run_fn([
         sample_mean,
         sample_variance,
         sample_stddev,
@@ -187,7 +191,7 @@ class VectorDistributionTestHelpers(object):
 
   def run_test_sample_consistent_log_prob(
       self,
-      sess,
+      sess_run_fn,
       dist,
       num_samples=int(1e5),
       radius=1.,
@@ -240,7 +244,9 @@ class VectorDistributionTestHelpers(object):
       https://en.wikipedia.org/wiki/Importance_sampling.
 
     Args:
-      sess: Tensorflow session.
+      sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and
+        returning a list of results after running one "step" of TensorFlow
+        computation, typically set to `sess.run`.
       dist: Distribution instance or object which implements `sample`,
         `log_prob`, `event_shape_tensor` and `batch_shape_tensor`. The
         distribution must have non-zero probability of sampling every point
@@ -301,8 +307,8 @@ class VectorDistributionTestHelpers(object):
       init_op = variables_ops.global_variables_initializer()
 
       # Execute graph.
-      sess.run(init_op)
-      [batch_shape_, actual_volume_, sample_volume_] = sess.run([
+      sess_run_fn(init_op)
+      [batch_shape_, actual_volume_, sample_volume_] = sess_run_fn([
           batch_shape, actual_volume, sample_volume])
 
       # Check results.
@@ -312,7 +318,7 @@ class VectorDistributionTestHelpers(object):
 
   def run_test_sample_consistent_mean_covariance(
       self,
-      sess,
+      sess_run_fn,
      dist,
      num_samples=int(1e5),
      seed=24,
@@ -326,7 +332,9 @@ class VectorDistributionTestHelpers(object):
     to the same distribution.
 
     Args:
-      sess: Tensorflow session.
+      sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and
+        returning a list of results after running one "step" of TensorFlow
+        computation, typically set to `sess.run`.
       dist: Distribution instance or object which implements `sample`,
         `log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
       num_samples: Python `int` scalar indicating the number of Monte-Carlo
@@ -360,7 +368,7 @@ class VectorDistributionTestHelpers(object):
         covariance_,
         variance_,
         stddev_
-    ] = sess.run([
+    ] = sess_run_fn([
         sample_mean,
         sample_covariance,
         sample_variance,
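The point of the `sess` to `sess_run_fn` change: the helpers only ever need
`sess.run` semantics, so accepting a callable lets a test close over a
`feed_dict` when its graph contains placeholders. A self-contained sketch of
the pattern (the `check_mean` helper is hypothetical, standing in for the
helpers above; TF 1.x graph mode assumed):

```python
import numpy as np
import tensorflow as tf

def check_mean(sess_run_fn, dist, rtol=0.05):
  # Only needs something that behaves like `sess.run`.
  sample_mean = tf.reduce_mean(dist.sample(int(1e4), seed=42), axis=0)
  sample_mean_, mean_ = sess_run_fn([sample_mean, dist.mean()])
  np.testing.assert_allclose(sample_mean_, mean_, rtol=rtol)

loc = tf.placeholder(dtype=tf.float32)
dist = tf.distributions.Normal(loc=loc, scale=1.)

with tf.Session() as sess:
  # Plain graphs can just pass the bound method: check_mean(sess.run, ...).
  # Graphs with placeholders close over the feed_dict instead:
  run_fn = lambda fetches: sess.run(fetches, feed_dict={loc: 3.})
  check_mean(run_fn, dist)
```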
tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py

@@ -73,8 +73,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
   denotes matrix multiplication. However, the non-approximate distribution does
   not have an analytical probability density function (pdf). Therefore the
   `VectorDiffeomixture` class implements an approximation based on
-  [Gauss-Hermite quadrature](
-  https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature). I.e., in
+  [numerical quadrature](
+  https://en.wikipedia.org/wiki/Numerical_integration) (default:
+  [Gauss--Hermite quadrature](
+  https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). I.e., in
   Note: although the `VectorDiffeomixture` is approximately the
   `SoftmaxNormal-Distribution` compound distribution, it is itself a valid
   distribution. It possesses a `sample`, `log_prob`, `mean`, `covariance` which
@@ -109,8 +111,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
   The `VectorDiffeomixture` approximates a SoftmaxNormal-mixed ("prior")
   [compound distribution](
   https://en.wikipedia.org/wiki/Compound_probability_distribution).
-  Using variable-substitution and [Gauss-Hermite quadrature](
-  https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature) we can
+  Using variable-substitution and [numerical quadrature](
+  https://en.wikipedia.org/wiki/Numerical_integration) (default:
+  [Gauss--Hermite quadrature](
+  https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)) we can
   redefine the distribution to be a parameter-less convex combination of `K`
   different affine combinations of a `d` iid samples from `distribution`.
 
@@ -141,7 +145,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
   and,
 
   ```none
-  grid, weight = np.polynomial.hermite.hermgauss(quadrature_degree)
+  grid, weight = np.polynomial.hermite.hermgauss(quadrature_size)
   prob[k] = weight[k] / sqrt(pi)
   lambda[k; i] = sigmoid(mix_loc[k] + sqrt(2) mix_scale[k] grid[i])
   ```
@@ -248,9 +252,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
         `k`-th element represents the `scale` used for the `k`-th affine
         transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`,
         `b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices
-      quadrature_grid_and_probs: Python pair of `list`-like objects representing
-        the sample points and the corresponding (possibly normalized) weight.
-        When `None`, defaults to: `np.polynomial.hermite.hermgauss(deg=8)`.
+      quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
+        representing the sample points and the corresponding (possibly
+        normalized) weight. When `None`, defaults to:
+        `np.polynomial.hermite.hermgauss(deg=8)`.
       validate_args: Python `bool`, default `False`. When `True` distribution
         parameters are checked for validity despite possibly degrading runtime
         performance. When `False` invalid inputs may silently render incorrect
@@ -317,24 +322,17 @@ class VectorDiffeomixture(distribution_lib.Distribution):
         raise NotImplementedError("Currently only bimixtures are supported; "
                                   "len(scale)={} is not 2.".format(len(scale)))
 
-      if quadrature_grid_and_probs is None:
-        grid, probs = np.polynomial.hermite.hermgauss(deg=8)
-      else:
-        grid, probs = tuple(quadrature_grid_and_probs)
-        if len(grid) != len(probs):
-          raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
-                           "same-length list-like objects")
-      grid = grid.astype(dtype.as_numpy_dtype)
-      probs = probs.astype(dtype.as_numpy_dtype)
-      probs /= np.linalg.norm(probs, ord=1)
+      grid, probs = distribution_util.process_quadrature_grid_and_probs(
+          quadrature_grid_and_probs, dtype, validate_args)
       self._quadrature_grid = grid
       self._quadrature_probs = probs
+      self._quadrature_size = distribution_util.dimension_size(probs, axis=0)
 
       # Note: by creating the logits as `log(prob)` we ensure that
       # `self.mixture_distribution.logits` is equivalent to
       # `math_ops.log(self.mixture_distribution.probs)`.
       self._mixture_distribution = categorical_lib.Categorical(
-          logits=np.log(probs),
+          logits=math_ops.log(probs),
           validate_args=validate_args,
           allow_nan_stats=allow_nan_stats)
@@ -361,10 +359,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
               validate_args=validate_args,
               name="interpolated_affine_{}".format(k))
           for k, (loc_, scale_) in enumerate(zip(
-              interpolate_loc(len(self._quadrature_grid),
+              interpolate_loc(self._quadrature_size,
                               self._interpolate_weight,
                               loc),
-              interpolate_scale(len(self._quadrature_grid),
+              interpolate_scale(self._quadrature_size,
                                 self._interpolate_weight,
                                 scale)))]
@@ -463,10 +461,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
           seed=distribution_util.gen_new_seed(
               seed, "vector_diffeomixture"))
 
-      # Stride `quadrature_degree` for `batch_size` number of times.
+      # Stride `quadrature_size` for `batch_size` number of times.
       offset = math_ops.range(start=0,
-                              limit=batch_size * len(self.quadrature_probs),
-                              delta=len(self.quadrature_probs),
+                              limit=batch_size * self._quadrature_size,
+                              delta=self._quadrature_size,
                               dtype=ids.dtype)
 
       weight = array_ops.gather(
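The stride-and-gather logic in the sampling hunk above flattens per-batch
parameter tables: each batch member draws a quadrature id in
`[0, quadrature_size)`, and the `offset` shifts those ids into a flattened
`[batch_size * quadrature_size]` table before `gather`. A numpy sketch of the
same indexing (all names illustrative):

```python
import numpy as np

batch_size, quadrature_size = 3, 4
table = np.arange(batch_size * quadrature_size)  # flattened per-batch params
ids = np.array([2, 0, 3])                        # one quadrature id per batch member

# Stride `quadrature_size` for `batch_size` number of times: [0, 4, 8].
offset = np.arange(start=0,
                   stop=batch_size * quadrature_size,
                   step=quadrature_size)

print(table[ids + offset])  # -> [2, 4, 11], i.e. one entry per batch member
```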
tensorflow/python/ops/distributions/util.py

@@ -29,6 +29,7 @@ from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 
@@ -1049,13 +1050,77 @@ def dimension_size(x, axis):
   """Returns the size of a specific dimension."""
   # Since tf.gather isn't "constant-in, constant-out", we must first check the
   # static shape or fallback to dynamic shape.
-  num_rows = (None if x.get_shape().ndims is None
-              else x.get_shape()[axis].value)
-  if num_rows is not None:
-    return num_rows
+  s = x.shape.with_rank_at_least(axis + 1)[axis].value
+  if axis > -1 and s is not None:
+    return s
   return array_ops.shape(x)[axis]
 
 
+def process_quadrature_grid_and_probs(
+    quadrature_grid_and_probs, dtype, validate_args, name=None):
+  """Validates quadrature grid, probs or computes them as necessary.
+
+  Args:
+    quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
+      representing the sample points and the corresponding (possibly
+      normalized) weight. When `None`, defaults to:
+      `np.polynomial.hermite.hermgauss(deg=8)`.
+    dtype: The expected `dtype` of `grid` and `probs`.
+    validate_args: Python `bool`, default `False`. When `True` distribution
+      parameters are checked for validity despite possibly degrading runtime
+      performance. When `False` invalid inputs may silently render incorrect
+      outputs.
+    name: Python `str` name prefixed to Ops created by this class.
+
+  Returns:
+    quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
+      representing the sample points and the corresponding (possibly
+      normalized) weight.
+
+  Raises:
+    ValueError: if `quadrature_grid_and_probs is not None` and
+      `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
+  """
+  with ops.name_scope(name, "process_quadrature_grid_and_probs",
+                      [quadrature_grid_and_probs]):
+    if quadrature_grid_and_probs is None:
+      grid, probs = np.polynomial.hermite.hermgauss(deg=8)
+      grid = grid.astype(dtype.as_numpy_dtype)
+      probs = probs.astype(dtype.as_numpy_dtype)
+      probs /= np.linalg.norm(probs, ord=1, keepdims=True)
+      grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
+      probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
+      return grid, probs
+
+    grid, probs = tuple(quadrature_grid_and_probs)
+    grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
+    probs = ops.convert_to_tensor(probs, name="unnormalized_probs",
+                                  dtype=dtype)
+    probs /= linalg_ops.norm(probs, ord=1, axis=-1, keep_dims=True,
+                             name="probs")
+
+    def _static_dim_size(x, axis):
+      """Returns the static size of a specific dimension or `None`."""
+      return x.shape.with_rank_at_least(axis + 1)[axis].value
+
+    m, n = _static_dim_size(probs, axis=0), _static_dim_size(grid, axis=0)
+    if m is not None and n is not None:
+      if m != n:
+        raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
+                         "same-length zero-th-dimension `Tensor`s "
+                         "(saw lengths {}, {})".format(m, n))
+    elif validate_args:
+      grid = control_flow_ops.with_dependencies([
+          check_ops.assert_equal(
+              dimension_size(probs, axis=0),
+              dimension_size(grid, axis=0),
+              message=("`quadrature_grid_and_probs` must be a `tuple` of "
+                       "same-length zero-th-dimension `Tensor`s")),
+      ], grid)
+
+    return grid, probs
+
+
 class AppendDocstring(object):
   """Helper class to promote private subclass docstring to public counterpart.
 
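A sketch of calling the new helper directly (TF 1.x graph mode; the module
path is the one in the diff, feed values are illustrative). Note the
static/dynamic split: with concrete arrays the length check raises
immediately, while with placeholders and `validate_args=True` it becomes a
runtime assertion:

```python
import numpy as np
import tensorflow as tf
from tensorflow.python.ops.distributions import util as distribution_util

qgrid = tf.placeholder(dtype=tf.float32)
qprobs = tf.placeholder(dtype=tf.float32)

# Returns Tensors; probs come back L1-normalized, and since the placeholder
# shapes are unknown, a runtime assert checks the lengths match.
grid, probs = distribution_util.process_quadrature_grid_and_probs(
    (qgrid, qprobs), dtype=tf.float32, validate_args=True)

g, p = np.polynomial.hermite.hermgauss(deg=8)
with tf.Session() as sess:
  g_, p_ = sess.run([grid, probs], feed_dict={qgrid: g, qprobs: p})
  assert np.isclose(p_.sum(), 1.)
```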