diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py index 98c09545ac7..25a9b6f5fe2 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py @@ -111,7 +111,7 @@ class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers, event_shape=[dims], validate_args=True) self.run_test_sample_consistent_log_prob( - sess=sess, + sess_run_fn=sess.run, dist=dist, num_samples=int(1e5), radius=1., @@ -130,7 +130,7 @@ class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers, event_shape=[dims], validate_args=True) self.run_test_sample_consistent_log_prob( - sess=sess, + sess_run_fn=sess.run, dist=dist, num_samples=int(1e5), radius=1., diff --git a/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py b/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py index 7a321db4b29..dcc66e89720 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py @@ -23,7 +23,6 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import independent as independent_lib from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib -from tensorflow.contrib.distributions.python.ops import test_util from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.platform import test @@ -41,8 +40,7 @@ def try_import(name): # pylint: disable=invalid-name stats = try_import("scipy.stats") -class ProductDistributionTest( - test_util.VectorDistributionTestHelpers, test.TestCase): +class ProductDistributionTest(test.TestCase): def testSampleAndLogProbUnivariate(self): loc = np.float32([-1., 1]) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py index ee4f989dac0..ece6bc077d9 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py @@ -94,10 +94,10 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers, loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5])) # Ball centered at component0's mean. self.run_test_sample_consistent_log_prob( - sess, gm, radius=1., center=[-1., 1], rtol=0.02) + sess.run, gm, radius=1., center=[-1., 1], rtol=0.02) # Larger ball centered at component1's mean. 
self.run_test_sample_consistent_log_prob( - sess, gm, radius=1., center=[1., -1], rtol=0.02) + sess.run, gm, radius=1., center=[1., -1], rtol=0.02) def testLogCdf(self): with self.test_session() as sess: @@ -122,7 +122,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers, mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), components_distribution=mvn_diag_lib.MultivariateNormalDiag( loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5])) - self.run_test_sample_consistent_mean_covariance(sess, gm) + self.run_test_sample_consistent_mean_covariance(sess.run, gm) def testVarianceConsistentCovariance(self): with self.test_session() as sess: diff --git a/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py index 3ded4159d86..3c0147b8cf6 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py @@ -22,6 +22,8 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import poisson_lognormal from tensorflow.contrib.distributions.python.ops import test_util +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -38,7 +40,7 @@ class PoissonLogNormalQuadratureCompoundTest( np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_log_prob( - sess, pln, rtol=0.1) + sess.run, pln, rtol=0.1) def testMeanVariance(self): with self.test_session() as sess: @@ -49,7 +51,7 @@ class PoissonLogNormalQuadratureCompoundTest( np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_mean_variance( - sess, pln, rtol=0.02) + sess.run, pln, rtol=0.02) def testSampleProbConsistentBroadcastScalar(self): with self.test_session() as sess: @@ -60,7 +62,7 @@ class PoissonLogNormalQuadratureCompoundTest( np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_log_prob( - sess, pln, rtol=0.1, atol=0.01) + sess.run, pln, rtol=0.1, atol=0.01) def testMeanVarianceBroadcastScalar(self): with self.test_session() as sess: @@ -71,7 +73,7 @@ class PoissonLogNormalQuadratureCompoundTest( np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_mean_variance( - sess, pln, rtol=0.1, atol=0.01) + sess.run, pln, rtol=0.1, atol=0.01) def testSampleProbConsistentBroadcastBoth(self): with self.test_session() as sess: @@ -82,7 +84,7 @@ class PoissonLogNormalQuadratureCompoundTest( np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_log_prob( - sess, pln, rtol=0.1, atol=0.08) + sess.run, pln, rtol=0.1, atol=0.08) def testMeanVarianceBroadcastBoth(self): with self.test_session() as sess: @@ -93,7 +95,21 @@ class PoissonLogNormalQuadratureCompoundTest( np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_mean_variance( - sess, pln, rtol=0.1, atol=0.01) + sess.run, pln, rtol=0.1, atol=0.01) + + def testSampleProbConsistentDynamicQuadrature(self): + with self.test_session() as sess: + qgrid = array_ops.placeholder(dtype=dtypes.float32) + qprobs = array_ops.placeholder(dtype=dtypes.float32) + g, p = np.polynomial.hermite.hermgauss(deg=10) + pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( + loc=-2., + scale=1.1, + quadrature_grid_and_probs=(qgrid, qprobs), +
validate_args=True) + self.run_test_sample_consistent_log_prob( + lambda x: sess.run(x, feed_dict={qgrid: g, qprobs: p}), + pln, rtol=0.1) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py index aea4d425038..de4a221f7ba 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py @@ -22,6 +22,8 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import test_util from tensorflow.contrib.distributions.python.ops import vector_diffeomixture as vector_diffeomixture_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib @@ -55,10 +57,10 @@ class VectorDiffeomixtureTest( validate_args=True) # Ball centered at component0's mean. self.run_test_sample_consistent_log_prob( - sess, vdm, radius=2., center=0., rtol=0.005) + sess.run, vdm, radius=2., center=0., rtol=0.005) # Larger ball centered at component1's mean. self.run_test_sample_consistent_log_prob( - sess, vdm, radius=4., center=2., rtol=0.005) + sess.run, vdm, radius=4., center=2., rtol=0.005) def testSampleProbConsistentBroadcastMixNonStandardBase(self): with self.test_session() as sess: @@ -83,10 +85,10 @@ class VectorDiffeomixtureTest( validate_args=True) # Ball centered at component0's mean. self.run_test_sample_consistent_log_prob( - sess, vdm, radius=2., center=1., rtol=0.006) + sess.run, vdm, radius=2., center=1., rtol=0.006) # Larger ball centered at component1's mean. self.run_test_sample_consistent_log_prob( - sess, vdm, radius=4., center=3., rtol=0.009) + sess.run, vdm, radius=4., center=3., rtol=0.009) def testSampleProbConsistentBroadcastMixBatch(self): with self.test_session() as sess: @@ -114,10 +116,10 @@ class VectorDiffeomixtureTest( validate_args=True) # Ball centered at component0's mean. self.run_test_sample_consistent_log_prob( - sess, vdm, radius=2., center=0., rtol=0.005) + sess.run, vdm, radius=2., center=0., rtol=0.005) # Larger ball centered at component1's mean. 
self.run_test_sample_consistent_log_prob( - sess, vdm, radius=4., center=2., rtol=0.005) + sess.run, vdm, radius=4., center=2., rtol=0.005) def testMeanCovarianceNoBatch(self): with self.test_session() as sess: @@ -141,7 +143,7 @@ class VectorDiffeomixtureTest( ], validate_args=True) self.run_test_sample_consistent_mean_covariance( - sess, vdm, rtol=0.02, cov_rtol=0.06) + sess.run, vdm, rtol=0.02, cov_rtol=0.06) def testMeanCovarianceNoBatchUncenteredNonStandardBase(self): with self.test_session() as sess: @@ -165,7 +167,7 @@ class VectorDiffeomixtureTest( ], validate_args=True) self.run_test_sample_consistent_mean_covariance( - sess, vdm, num_samples=int(1e6), rtol=0.01, cov_atol=0.025) + sess.run, vdm, num_samples=int(1e6), rtol=0.01, cov_atol=0.025) def testMeanCovarianceBatch(self): with self.test_session() as sess: @@ -192,7 +194,40 @@ class VectorDiffeomixtureTest( ], validate_args=True) self.run_test_sample_consistent_mean_covariance( - sess, vdm, rtol=0.02, cov_rtol=0.06) + sess.run, vdm, rtol=0.02, cov_rtol=0.06) + + def testSampleProbConsistentDynamicQuadrature(self): + with self.test_session() as sess: + qgrid = array_ops.placeholder(dtype=dtypes.float32) + qprobs = array_ops.placeholder(dtype=dtypes.float32) + g, p = np.polynomial.hermite.hermgauss(deg=8) + dims = 4 + vdm = vector_diffeomixture_lib.VectorDiffeomixture( + mix_loc=[[0.], [1.]], + mix_scale=[1.], + distribution=normal_lib.Normal(0., 1.), + loc=[ + None, + np.float32([2.]*dims), + ], + scale=[ + linop_identity_lib.LinearOperatorScaledIdentity( + num_rows=dims, + multiplier=np.float32(1.1), + is_positive_definite=True), + linop_diag_lib.LinearOperatorDiag( + diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), + is_positive_definite=True), + ], + quadrature_grid_and_probs=(qgrid, qprobs), + validate_args=True) + # Ball centered at component0's mean. + sess_run_fn = lambda x: sess.run(x, feed_dict={qgrid: g, qprobs: p}) + self.run_test_sample_consistent_log_prob( + sess_run_fn, vdm, radius=2., center=0., rtol=0.005) + # Larger ball centered at component1's mean. + self.run_test_sample_consistent_log_prob( + sess_run_fn, vdm, radius=4., center=2., rtol=0.005) # TODO(jvdillon): We've tested that (i) .sample and .log_prob are consistent, # (ii) .mean, .stddev etc... and .sample are consistent.
However, we haven't diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_sinh_arcsinh_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_sinh_arcsinh_diag_test.py index a5d837d4541..2bc6a926dd6 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/vector_sinh_arcsinh_diag_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_sinh_arcsinh_diag_test.py @@ -210,15 +210,15 @@ class VectorSinhArcsinhDiagTest(test_util.VectorDistributionTestHelpers, validate_args=True) self.run_test_sample_consistent_log_prob( - sess, sasnorm, radius=1.0, center=0., rtol=0.1) + sess.run, sasnorm, radius=1.0, center=0., rtol=0.1) self.run_test_sample_consistent_log_prob( - sess, + sess.run, sasnorm, radius=1.0, center=-0.15, rtol=0.1) self.run_test_sample_consistent_log_prob( - sess, + sess.run, sasnorm, radius=1.0, center=0.15, @@ -237,15 +237,15 @@ class VectorSinhArcsinhDiagTest(test_util.VectorDistributionTestHelpers, validate_args=True) self.run_test_sample_consistent_log_prob( - sess, sasnorm, radius=1.0, center=0., rtol=0.1) + sess.run, sasnorm, radius=1.0, center=0., rtol=0.1) self.run_test_sample_consistent_log_prob( - sess, + sess.run, sasnorm, radius=1.0, center=-0.15, rtol=0.1) self.run_test_sample_consistent_log_prob( - sess, + sess.run, sasnorm, radius=1.0, center=0.15, diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index 80d4e2dc5ef..8a95038a3c8 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -29,7 +30,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import categorical as categorical_lib from tensorflow.python.ops.distributions import distribution as distribution_lib -from tensorflow.python.ops.distributions import util as distribution_util __all__ = [ @@ -55,8 +55,10 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): ``` where `lambda(z) = exp(sqrt(2) scale z + loc)` and the `prob,grid` terms - are from [Gauss--Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature). Note that + are from [numerical quadrature]( + https://en.wikipedia.org/wiki/Numerical_integration) (default: + [Gauss--Hermite quadrature]( + https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). Note that the second line made the substitution: `z(l) = (log(l) - loc) / (sqrt(2) scale)` which implies `lambda(z)` [above] and `dl = sqrt(2) scale lambda(z) dz` @@ -65,8 +67,11 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): Poisson rate parameter. Unfortunately, the non-approximate distribution lacks an analytical probability density function (pdf). Therefore the `PoissonLogNormalQuadratureCompound` class implements an approximation based - on [Gauss-Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature). 
+ on [numerical quadrature]( + https://en.wikipedia.org/wiki/Numerical_integration) (default: + [Gauss--Hermite quadrature]( + https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). + Note: although the `PoissonLogNormalQuadratureCompound` is approximately the Poisson-LogNormal compound distribution, it is itself a valid distribution. Viz., it possesses a `sample`, `log_prob`, `mean`, `variance`, etc. which are @@ -76,9 +81,11 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): The `PoissonLogNormalQuadratureCompound` approximates a Poisson-LogNormal [compound distribution]( - https://en.wikipedia.org/wiki/Compound_probability_distribution). - Using variable-substitution and [Gauss-Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature) we can + https://en.wikipedia.org/wiki/Compound_probability_distribution). Using + variable-substitution and [numerical quadrature]( + https://en.wikipedia.org/wiki/Numerical_integration) (default: + [Gauss--Hermite quadrature]( + https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)) we can redefine the distribution to be a parameter-less convex combination of `deg` different Poisson samples. @@ -125,9 +132,10 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): the LogNormal prior. scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of the LogNormal prior. - quadrature_grid_and_probs: Python pair of `list`-like objects representing - the sample points and the corresponding (possibly normalized) weight. - When `None`, defaults to: `np.polynomial.hermite.hermgauss(deg=8)`. + quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s + representing the sample points and the corresponding (possibly + normalized) weight. When `None`, defaults to: + `np.polynomial.hermite.hermgauss(deg=8)`. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -140,8 +148,6 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): Raises: TypeError: if `loc.dtype != scale[0].dtype`. 
- ValueError: if `quadrature_grid_and_probs is not None` and - `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])` """ parameters = locals() with ops.name_scope(name, values=[loc, scale]): @@ -157,21 +163,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): "loc.dtype(\"{}\") does not match scale.dtype(\"{}\")".format( loc.dtype.name, scale.dtype.name)) - if quadrature_grid_and_probs is None: - grid, probs = np.polynomial.hermite.hermgauss(deg=8) - else: - grid, probs = tuple(quadrature_grid_and_probs) - if len(grid) != len(probs): - raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of " - "same-length list-like objects") - grid = grid.astype(dtype.as_numpy_dtype) - probs = probs.astype(dtype.as_numpy_dtype) - probs /= np.linalg.norm(probs, ord=1) + grid, probs = distribution_util.process_quadrature_grid_and_probs( + quadrature_grid_and_probs, dtype, validate_args) self._quadrature_grid = grid self._quadrature_probs = probs + self._quadrature_size = distribution_util.dimension_size(probs, axis=0) self._mixture_distribution = categorical_lib.Categorical( - logits=np.log(probs), + logits=math_ops.log(self._quadrature_probs), validate_args=validate_args, allow_nan_stats=allow_nan_stats) @@ -254,10 +253,10 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): [batch_size])), seed=distribution_util.gen_new_seed( seed, "poisson_lognormal_quadrature_compound")) - # Stride `quadrature_degree` for `batch_size` number of times. + # Stride `quadrature_size` for `batch_size` number of times. offset = math_ops.range(start=0, - limit=batch_size * len(self.quadrature_probs), - delta=len(self.quadrature_probs), + limit=batch_size * self._quadrature_size, + delta=self._quadrature_size, dtype=ids.dtype) ids += offset rate = array_ops.gather( diff --git a/tensorflow/contrib/distributions/python/ops/test_util.py b/tensorflow/contrib/distributions/python/ops/test_util.py index 631ffc1bacd..77f2a39273d 100644 --- a/tensorflow/contrib/distributions/python/ops/test_util.py +++ b/tensorflow/contrib/distributions/python/ops/test_util.py @@ -38,7 +38,7 @@ class DiscreteScalarDistributionTestHelpers(object): """DiscreteScalarDistributionTestHelpers.""" def run_test_sample_consistent_log_prob( - self, sess, dist, + self, sess_run_fn, dist, num_samples=int(1e5), num_threshold=int(1e3), seed=42, rtol=1e-2, atol=0.): """Tests that sample/log_prob are consistent with each other. @@ -51,7 +51,9 @@ class DiscreteScalarDistributionTestHelpers(object): are consistent. Args: - sess: Tensorflow session. + sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and + returning a list of results after running one "step" of TensorFlow + computation, typically set to `sess.run`. dist: Distribution instance or object which implements `sample`, `log_prob`, `event_shape_tensor` and `batch_shape_tensor`. 
num_samples: Python `int` scalar indicating the number of Monte-Carlo @@ -87,7 +89,7 @@ class DiscreteScalarDistributionTestHelpers(object): probs = math_ops.exp(dist.log_prob(edges)) probs = array_ops.reshape(probs, shape=[-1, batch_size])[:, b] - [counts_, probs_] = sess.run([counts, probs]) + [counts_, probs_] = sess_run_fn([counts, probs]) valid = counts_ > num_threshold probs_ = probs_[valid] counts_ = counts_[valid] @@ -95,7 +97,7 @@ class DiscreteScalarDistributionTestHelpers(object): rtol=rtol, atol=atol) def run_test_sample_consistent_mean_variance( - self, sess, dist, + self, sess_run_fn, dist, num_samples=int(1e5), seed=24, rtol=1e-2, atol=0.): """Tests that sample/mean/variance are consistent with each other. @@ -104,7 +106,9 @@ class DiscreteScalarDistributionTestHelpers(object): to the same distribution. Args: - sess: Tensorflow session. + sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and + returning a list of results after running one "step" of TensorFlow + computation, typically set to `sess.run`. dist: Distribution instance or object which implements `sample`, `log_prob`, `event_shape_tensor` and `batch_shape_tensor`. num_samples: Python `int` scalar indicating the number of Monte-Carlo @@ -130,7 +134,7 @@ class DiscreteScalarDistributionTestHelpers(object): mean_, variance_, stddev_ - ] = sess.run([ + ] = sess_run_fn([ sample_mean, sample_variance, sample_stddev, @@ -187,7 +191,7 @@ class VectorDistributionTestHelpers(object): def run_test_sample_consistent_log_prob( self, - sess, + sess_run_fn, dist, num_samples=int(1e5), radius=1., @@ -240,7 +244,9 @@ class VectorDistributionTestHelpers(object): https://en.wikipedia.org/wiki/Importance_sampling. Args: - sess: Tensorflow session. + sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and + returning a list of results after running one "step" of TensorFlow + computation, typically set to `sess.run`. dist: Distribution instance or object which implements `sample`, `log_prob`, `event_shape_tensor` and `batch_shape_tensor`. The distribution must have non-zero probability of sampling every point @@ -301,8 +307,8 @@ class VectorDistributionTestHelpers(object): init_op = variables_ops.global_variables_initializer() # Execute graph. - sess.run(init_op) - [batch_shape_, actual_volume_, sample_volume_] = sess.run([ + sess_run_fn(init_op) + [batch_shape_, actual_volume_, sample_volume_] = sess_run_fn([ batch_shape, actual_volume, sample_volume]) # Check results. @@ -312,7 +318,7 @@ class VectorDistributionTestHelpers(object): def run_test_sample_consistent_mean_covariance( self, - sess, + sess_run_fn, dist, num_samples=int(1e5), seed=24, @@ -326,7 +332,9 @@ class VectorDistributionTestHelpers(object): to the same distribution. Args: - sess: Tensorflow session. + sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and + returning a list of results after running one "step" of TensorFlow + computation, typically set to `sess.run`. dist: Distribution instance or object which implements `sample`, `log_prob`, `event_shape_tensor` and `batch_shape_tensor`. 
num_samples: Python `int` scalar indicating the number of Monte-Carlo @@ -360,7 +368,7 @@ class VectorDistributionTestHelpers(object): covariance_, variance_, stddev_ - ] = sess.run([ + ] = sess_run_fn([ sample_mean, sample_covariance, sample_variance, diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index 33dad811a90..92043d6a088 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -73,8 +73,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): denotes matrix multiplication. However, the non-approximate distribution does not have an analytical probability density function (pdf). Therefore the `VectorDiffeomixture` class implements an approximation based on - [Gauss-Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature). I.e., in + [numerical quadrature]( + https://en.wikipedia.org/wiki/Numerical_integration) (default: + [Gauss--Hermite quadrature]( + https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). I.e., in Note: although the `VectorDiffeomixture` is approximately the `SoftmaxNormal-Distribution` compound distribution, it is itself a valid distribution. It possesses a `sample`, `log_prob`, `mean`, `covariance` which @@ -109,8 +111,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): The `VectorDiffeomixture` approximates a SoftmaxNormal-mixed ("prior") [compound distribution]( https://en.wikipedia.org/wiki/Compound_probability_distribution). - Using variable-substitution and [Gauss-Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature) we can + Using variable-substitution and [numerical quadrature]( + https://en.wikipedia.org/wiki/Numerical_integration) (default: + [Gauss--Hermite quadrature]( + https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)) we can redefine the distribution to be a parameter-less convex combination of `K` different affine combinations of a `d` iid samples from `distribution`. @@ -141,7 +145,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): and, ```none - grid, weight = np.polynomial.hermite.hermgauss(quadrature_degree) + grid, weight = np.polynomial.hermite.hermgauss(quadrature_size) prob[k] = weight[k] / sqrt(pi) lambda[k; i] = sigmoid(mix_loc[k] + sqrt(2) mix_scale[k] grid[i]) ``` @@ -248,9 +252,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): `k`-th element represents the `scale` used for the `k`-th affine transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`, `b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices - quadrature_grid_and_probs: Python pair of `list`-like objects representing - the sample points and the corresponding (possibly normalized) weight. - When `None`, defaults to: `np.polynomial.hermite.hermgauss(deg=8)`. + quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s + representing the sample points and the corresponding (possibly + normalized) weight. When `None`, defaults to: + `np.polynomial.hermite.hermgauss(deg=8)`. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. 
When `False` invalid inputs may silently render incorrect @@ -317,24 +322,17 @@ class VectorDiffeomixture(distribution_lib.Distribution): raise NotImplementedError("Currently only bimixtures are supported; " "len(scale)={} is not 2.".format(len(scale))) - if quadrature_grid_and_probs is None: - grid, probs = np.polynomial.hermite.hermgauss(deg=8) - else: - grid, probs = tuple(quadrature_grid_and_probs) - if len(grid) != len(probs): - raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of " - "same-length list-like objects") - grid = grid.astype(dtype.as_numpy_dtype) - probs = probs.astype(dtype.as_numpy_dtype) - probs /= np.linalg.norm(probs, ord=1) + grid, probs = distribution_util.process_quadrature_grid_and_probs( + quadrature_grid_and_probs, dtype, validate_args) self._quadrature_grid = grid self._quadrature_probs = probs + self._quadrature_size = distribution_util.dimension_size(probs, axis=0) # Note: by creating the logits as `log(prob)` we ensure that # `self.mixture_distribution.logits` is equivalent to # `math_ops.log(self.mixture_distribution.probs)`. self._mixture_distribution = categorical_lib.Categorical( - logits=np.log(probs), + logits=math_ops.log(probs), validate_args=validate_args, allow_nan_stats=allow_nan_stats) @@ -361,10 +359,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): validate_args=validate_args, name="interpolated_affine_{}".format(k)) for k, (loc_, scale_) in enumerate(zip( - interpolate_loc(len(self._quadrature_grid), + interpolate_loc(self._quadrature_size, self._interpolate_weight, loc), - interpolate_scale(len(self._quadrature_grid), + interpolate_scale(self._quadrature_size, self._interpolate_weight, scale)))] @@ -463,10 +461,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): seed=distribution_util.gen_new_seed( seed, "vector_diffeomixture")) - # Stride `quadrature_degree` for `batch_size` number of times. + # Stride `quadrature_size` for `batch_size` number of times. offset = math_ops.range(start=0, - limit=batch_size * len(self.quadrature_probs), - delta=len(self.quadrature_probs), + limit=batch_size * self._quadrature_size, + delta=self._quadrature_size, dtype=ids.dtype) weight = array_ops.gather( diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py index f261d996b54..41b86f79409 100644 --- a/tensorflow/python/ops/distributions/util.py +++ b/tensorflow/python/ops/distributions/util.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn @@ -1049,13 +1050,77 @@ def dimension_size(x, axis): """Returns the size of a specific dimension.""" # Since tf.gather isn't "constant-in, constant-out", we must first check the # static shape or fallback to dynamic shape. - num_rows = (None if x.get_shape().ndims is None - else x.get_shape()[axis].value) - if num_rows is not None: - return num_rows + s = x.shape.with_rank_at_least(axis + 1)[axis].value + if axis > -1 and s is not None: + return s return array_ops.shape(x)[axis] +def process_quadrature_grid_and_probs( + quadrature_grid_and_probs, dtype, validate_args, name=None): + """Validates quadrature grid, probs or computes them as necessary. 
+ + Args: + quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s + representing the sample points and the corresponding (possibly + normalized) weight. When `None`, defaults to: + `np.polynomial.hermite.hermgauss(deg=8)`. + dtype: The expected `dtype` of `grid` and `probs`. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + name: Python `str` name prefixed to Ops created by this class. + + Returns: + quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s + representing the sample points and the corresponding (possibly + normalized) weight. + + Raises: + ValueError: if `quadrature_grid_and_probs is not None` and + `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])` + """ + with ops.name_scope(name, "process_quadrature_grid_and_probs", + [quadrature_grid_and_probs]): + if quadrature_grid_and_probs is None: + grid, probs = np.polynomial.hermite.hermgauss(deg=8) + grid = grid.astype(dtype.as_numpy_dtype) + probs = probs.astype(dtype.as_numpy_dtype) + probs /= np.linalg.norm(probs, ord=1, keepdims=True) + grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype) + probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype) + return grid, probs + + grid, probs = tuple(quadrature_grid_and_probs) + grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype) + probs = ops.convert_to_tensor(probs, name="unnormalized_probs", + dtype=dtype) + probs /= linalg_ops.norm(probs, ord=1, axis=-1, keep_dims=True, + name="probs") + + def _static_dim_size(x, axis): + """Returns the static size of a specific dimension or `None`.""" + return x.shape.with_rank_at_least(axis + 1)[axis].value + + m, n = _static_dim_size(probs, axis=0), _static_dim_size(grid, axis=0) + if m is not None and n is not None: + if m != n: + raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of " + "same-length zero-th-dimension `Tensor`s " + "(saw lengths {}, {})".format(m, n)) + elif validate_args: + grid = control_flow_ops.with_dependencies([ + check_ops.assert_equal( + dimension_size(probs, axis=0), + dimension_size(grid, axis=0), + message=("`quadrature_grid_and_probs` must be a `tuple` of " + "same-length zero-th-dimension `Tensor`s")), + ], grid) + + return grid, probs + + class AppendDocstring(object): """Helper class to promote private subclass docstring to public counterpart.
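The docstrings in poisson_lognormal.py above describe the approximation p(k) ~= sum_i prob[i] Poisson(k | lambda(grid[i])) with lambda(z) = exp(sqrt(2) scale z + loc). Below is a minimal numpy/scipy sketch of that formula, not part of the patch; variable names are illustrative and scipy is assumed available, as in the kernel tests.

import numpy as np
from scipy.stats import poisson

loc, scale, deg = -2., 1.1, 10
grid, probs = np.polynomial.hermite.hermgauss(deg=deg)
probs /= np.linalg.norm(probs, ord=1)  # same 1-norm normalization as the class
rates = np.exp(np.sqrt(2.) * scale * grid + loc)  # lambda(z) from the docstring
k = np.arange(50)
# p(k) ~= sum_i probs[i] * Poisson(k | lambda(grid[i]))
pmf = poisson.pmf(k[:, np.newaxis], rates).dot(probs)
print(pmf.sum())  # ~1: the approximation is itself a valid finite mixture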
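The test-helper signature change (sess -> sess_run_fn) is what enables the new dynamic-quadrature tests: the helpers now accept any callable with sess.run semantics, so a grid supplied through placeholders can be bound via a closure. A sketch of the calling convention, mirroring testSampleProbConsistentDynamicQuadrature above (assumes the same imports and an open `sess` as in poisson_lognormal_test.py):

g, p = np.polynomial.hermite.hermgauss(deg=10)
qgrid = array_ops.placeholder(dtype=dtypes.float32)
qprobs = array_ops.placeholder(dtype=dtypes.float32)
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
    loc=-2.,
    scale=1.1,
    quadrature_grid_and_probs=(qgrid, qprobs),  # dynamic: fed at run time
    validate_args=True)
# Any callable taking fetches and returning their values works:
sess_run_fn = lambda fetches: sess.run(fetches, feed_dict={qgrid: g, qprobs: p})
# With a statically-known grid nothing needs feeding, so the bound method
# `sess.run` itself suffices, as in the other updated tests.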