From 4f60ddb470905455e74d328a4efd2528675f8f53 Mon Sep 17 00:00:00 2001 From: "Joshua V. Dillon" Date: Fri, 28 Jul 2017 09:50:13 -0700 Subject: [PATCH] Add Poisson-LogNormal (approximate) compound distribution. PiperOrigin-RevId: 163480957 --- tensorflow/contrib/distributions/BUILD | 20 +- tensorflow/contrib/distributions/__init__.py | 2 + .../kernel_tests/poisson_lognormal_test.py | 92 +++++ .../kernel_tests/vector_diffeomixture_test.py | 226 +---------- .../python/ops/poisson_lognormal.py | 313 +++++++++++++++ .../distributions/python/ops/test_util.py | 378 ++++++++++++++++++ .../python/ops/vector_diffeomixture.py | 19 +- 7 files changed, 815 insertions(+), 235 deletions(-) create mode 100644 tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py create mode 100644 tensorflow/contrib/distributions/python/ops/poisson_lognormal.py create mode 100644 tensorflow/contrib/distributions/python/ops/test_util.py diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 94e5c3785b9..fa8bb9a45c3 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -44,6 +44,7 @@ py_library( "//tensorflow/python:control_flow_ops", "//tensorflow/python:data_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:histogram_ops", "//tensorflow/python:init_ops", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", @@ -327,6 +328,17 @@ cuda_py_test( ], ) +cuda_py_test( + name = "poisson_lognormal_test", + size = "small", + srcs = ["python/kernel_tests/poisson_lognormal_test.py"], + additional_deps = [ + ":distributions_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "sample_stats_test", size = "small", @@ -518,19 +530,13 @@ cuda_py_test( cuda_py_test( name = "vector_diffeomixture_test", - size = "large", + size = "small", srcs = ["python/kernel_tests/vector_diffeomixture_test.py"], additional_deps = [ - ":bijectors_py", ":distributions_py", "//third_party/py/numpy", "//tensorflow/contrib/linalg:linalg_py", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", - "//tensorflow/python:platform_test", ], ) diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 7a2aebddd25..3bbf1c2f5e2 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -44,6 +44,7 @@ from tensorflow.contrib.distributions.python.ops.negative_binomial import * from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import * from tensorflow.contrib.distributions.python.ops.onehot_categorical import * from tensorflow.contrib.distributions.python.ops.poisson import * +from tensorflow.contrib.distributions.python.ops.poisson_lognormal import * from tensorflow.contrib.distributions.python.ops.quantized_distribution import * from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import * from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import * @@ -117,6 +118,7 @@ _allowed_symbols = [ 'Normal', 'NormalWithSoftplusScale', 'Poisson', + 'PoissonLogNormalQuadratureCompound', 'StudentT', 'StudentTWithAbsDfSoftplusScale', 'Uniform', diff --git a/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py 
b/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py
new file mode 100644
index 00000000000..7cb46bb2367
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py
@@ -0,0 +1,92 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for PoissonLogNormalQuadratureCompound."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distributions.python.ops import poisson_lognormal
+from tensorflow.contrib.distributions.python.ops import test_util
+from tensorflow.python.platform import test
+
+
+class PoissonLogNormalQuadratureCompoundTest(
+    test_util.DiscreteScalarDistributionTestHelpers, test.TestCase):
+  """Tests the PoissonLogNormalQuadratureCompound distribution."""
+
+  def testSampleProbConsistent(self):
+    with self.test_session() as sess:
+      pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
+          loc=-2.,
+          scale=1.1,
+          quadrature_polynomial_degree=10,
+          validate_args=True)
+      self.run_test_sample_consistent_log_prob(
+          sess, pln, rtol=0.1)
+
+  def testMeanVariance(self):
+    with self.test_session() as sess:
+      pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
+          loc=0.,
+          scale=1.,
+          quadrature_polynomial_degree=10,
+          validate_args=True)
+      self.run_test_sample_consistent_mean_variance(
+          sess, pln, rtol=0.02)
+
+  def testSampleProbConsistentBroadcastScalar(self):
+    with self.test_session() as sess:
+      pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
+          loc=[0., -0.5],
+          scale=1.,
+          quadrature_polynomial_degree=10,
+          validate_args=True)
+      self.run_test_sample_consistent_log_prob(
+          sess, pln, rtol=0.1, atol=0.01)
+
+  def testMeanVarianceBroadcastScalar(self):
+    with self.test_session() as sess:
+      pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
+          loc=[0., -0.5],
+          scale=1.,
+          quadrature_polynomial_degree=10,
+          validate_args=True)
+      self.run_test_sample_consistent_mean_variance(
+          sess, pln, rtol=0.1, atol=0.01)
+
+  def testSampleProbConsistentBroadcastBoth(self):
+    with self.test_session() as sess:
+      pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
+          loc=[[0.], [-0.5]],
+          scale=[[1., 0.9]],
+          quadrature_polynomial_degree=10,
+          validate_args=True)
+      self.run_test_sample_consistent_log_prob(
+          sess, pln, rtol=0.1, atol=0.08)
+
+  def testMeanVarianceBroadcastBoth(self):
+    with self.test_session() as sess:
+      pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
+          loc=[[0.], [-0.5]],
+          scale=[[1., 0.9]],
+          quadrature_polynomial_degree=10,
+          validate_args=True)
+      self.run_test_sample_consistent_mean_variance(
+          sess, pln, rtol=0.1, atol=0.01)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py
b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py index 62ffbea1b5c..0825d572ccb 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py @@ -20,236 +20,16 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.distributions.python.ops import test_util from tensorflow.contrib.distributions.python.ops import vector_diffeomixture as vector_diffeomixture_lib from tensorflow.contrib.linalg.python.ops import linear_operator_diag as linop_diag_lib from tensorflow.contrib.linalg.python.ops import linear_operator_identity as linop_identity_lib -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.platform import test -class VectorDistributionTestHelpers(object): - """VectorDistributionTestHelpers helps test vector-event distributions.""" - - def linop(self, num_rows=None, multiplier=None, diag=None): - """Helper to create non-singular, symmetric, positive definite matrices.""" - if num_rows is not None and multiplier is not None: - if any(p is not None for p in [diag]): - raise ValueError("Found extra args for scaled identity.") - return linop_identity_lib.LinearOperatorScaledIdentity( - num_rows=num_rows, - multiplier=multiplier, - is_positive_definite=True) - elif num_rows is not None: - if any(p is not None for p in [multiplier, diag]): - raise ValueError("Found extra args for identity.") - return linop_identity_lib.LinearOperatorIdentity( - num_rows=num_rows, - is_positive_definite=True) - elif diag is not None: - if any(p is not None for p in [num_rows, multiplier]): - raise ValueError("Found extra args for diag.") - return linop_diag_lib.LinearOperatorDiag( - diag=diag, - is_positive_definite=True) - else: - raise ValueError("Must specify at least one arg.") - - def run_test_sample_consistent_log_prob( - self, - sess, - dist, - num_samples=int(1e5), - radius=1., - center=0., - seed=42, - rtol=1e-2, - atol=0.): - """Tests that sample/log_prob are mutually consistent. - - "Consistency" means that `sample` and `log_prob` correspond to the same - distribution. - - The idea of this test is to compute the Monte-Carlo estimate of the volume - enclosed by a hypersphere, i.e., the volume of an `n`-ball. While we could - choose an arbitrary function to integrate, the hypersphere's volume is nice - because it is intuitive, has an easy analytical expression, and works for - `dimensions > 1`. - - Technical Details: - - Observe that: - - ```none - int_{R**d} dx [x in Ball(radius=r, center=c)] - = E_{p(X)}[ [X in Ball(r, c)] / p(X) ] - = lim_{m->infty} m**-1 sum_j^m [x[j] in Ball(r, c)] / p(x[j]), - where x[j] ~iid p(X) - ``` - - Thus, for fixed `m`, the above is approximately true when `sample` and - `log_prob` are mutually consistent. - - Furthermore, the above calculation has the analytical result: - `pi**(d/2) r**d / Gamma(1 + d/2)`. - - Note: this test only verifies a necessary condition for consistency--it does - does not verify sufficiency hence does not prove `sample`, `log_prob` truly - are consistent. For this reason we recommend testing several different - hyperspheres (assuming the hypersphere is supported by the distribution). 
- Furthermore, we gain additional trust in this test when also tested `sample` - against the first, second moments - (`run_test_sample_consistent_mean_covariance`); it is probably unlikely that - a "best-effort" implementation of `log_prob` would incorrectly pass both - tests and for different hyperspheres. - - For a discussion on the analytical result (second-line) see: - https://en.wikipedia.org/wiki/Volume_of_an_n-ball. - - For a discussion of importance sampling (fourth-line) see: - https://en.wikipedia.org/wiki/Importance_sampling. - - Args: - sess: Tensorflow session. - dist: Distribution instance or object which implements `sample`, - `log_prob`, `event_shape_tensor` and `batch_shape_tensor`. The - distribution must have non-zero probability of sampling every point - enclosed by the hypersphere. - num_samples: Python `int` scalar indicating the number of Monte-Carlo - samples to draw from `dist`. - radius: Python `float`-type indicating the radius of the `n`-ball which - we're computing the volume. - center: Python floating-type vector (or scalar) indicating the center of - the `n`-ball which we're computing the volume. When scalar, the value is - broadcast to all event dims. - seed: Python `int` indicating the seed to use when sampling from `dist`. - In general it is not recommended to use `None` during a test as this - increases the likelihood of spurious test failure. - rtol: Python `float`-type indicating the admissible relative error between - actual- and approximate-volumes. - atol: Python `float`-type indicating the admissible absolute error between - actual- and approximate-volumes. In general this should be zero since - a typical radius implies a non-zero volume. - """ - - def actual_hypersphere_volume(dims, radius): - # https://en.wikipedia.org/wiki/Volume_of_an_n-ball - # Using tf.lgamma because we'd have to otherwise use SciPy which is not - # a required dependency of core. - radius = np.asarray(radius) - dims = math_ops.cast(dims, dtype=radius.dtype) - return math_ops.exp( - (dims / 2.) * np.log(np.pi) - - math_ops.lgamma(1. + dims / 2.) - + dims * math_ops.log(radius)) - - def is_in_ball(x, radius, center): - return math_ops.cast(linalg_ops.norm(x - center, axis=-1) <= radius, - dtype=x.dtype) - - def monte_carlo_hypersphere_volume(dist, num_samples, radius, center): - # https://en.wikipedia.org/wiki/Importance_sampling - x = dist.sample(num_samples, seed=seed) - return math_ops.reduce_mean( - math_ops.exp(-dist.log_prob(x)) * is_in_ball(x, radius, center), - axis=0) - - [ - batch_shape_, - actual_volume_, - sample_volume_, - ] = sess.run([ - dist.batch_shape_tensor(), - actual_hypersphere_volume( - dims=dist.event_shape_tensor()[0], - radius=radius), - monte_carlo_hypersphere_volume( - dist, - num_samples=num_samples, - radius=radius, - center=center), - ]) - - self.assertAllClose(np.tile(actual_volume_, reps=batch_shape_), - sample_volume_, - rtol=rtol, atol=atol) - - def run_test_sample_consistent_mean_covariance( - self, - sess, - dist, - num_samples=int(1e5), - seed=24, - rtol=1e-2, - atol=0., - cov_rtol=None, - cov_atol=None): - """Tests that sample/mean/covariance are consistent with each other. - - "Consistency" means that `sample`, `mean`, `covariance`, etc all correspond - to the same distribution. - - Args: - sess: Tensorflow session. - dist: Distribution instance or object which implements `sample`, - `log_prob`, `event_shape_tensor` and `batch_shape_tensor`. 
- num_samples: Python `int` scalar indicating the number of Monte-Carlo - samples to draw from `dist`. - seed: Python `int` indicating the seed to use when sampling from `dist`. - In general it is not recommended to use `None` during a test as this - increases the likelihood of spurious test failure. - rtol: Python `float`-type indicating the admissible relative error between - analytical and sample statistics. - atol: Python `float`-type indicating the admissible absolute error between - analytical and sample statistics. - cov_rtol: Python `float`-type indicating the admissible relative error - between analytical and sample covariance. Default: rtol. - cov_atol: Python `float`-type indicating the admissible absolute error - between analytical and sample covariance. Default: atol. - """ - - def vec_osquare(x): - """Computes the outer-product of a vector, i.e., x.T x.""" - return x[..., :, array_ops.newaxis] * x[..., array_ops.newaxis, :] - - x = dist.sample(num_samples, seed=seed) - sample_mean = math_ops.reduce_mean(x, axis=0) - sample_covariance = math_ops.reduce_mean( - vec_osquare(x - sample_mean), axis=0) - sample_variance = array_ops.matrix_diag_part(sample_covariance) - sample_stddev = math_ops.sqrt(sample_variance) - - [ - sample_mean_, - sample_covariance_, - sample_variance_, - sample_stddev_, - mean_, - covariance_, - variance_, - stddev_ - ] = sess.run([ - sample_mean, - sample_covariance, - sample_variance, - sample_stddev, - dist.mean(), - dist.covariance(), - dist.variance(), - dist.stddev(), - ]) - - self.assertAllClose(mean_, sample_mean_, rtol=rtol, atol=atol) - self.assertAllClose(covariance_, sample_covariance_, - rtol=cov_rtol or rtol, - atol=cov_atol or atol) - self.assertAllClose(variance_, sample_variance_, rtol=rtol, atol=atol) - self.assertAllClose(stddev_, sample_stddev_, rtol=rtol, atol=atol) - - -class VectorDiffeomixtureTest(VectorDistributionTestHelpers, test.TestCase): +class VectorDiffeomixtureTest( + test_util.VectorDistributionTestHelpers, test.TestCase): """Tests the VectorDiffeomixture distribution.""" def testSampleProbConsistentBroadcastMix(self): diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py new file mode 100644 index 00000000000..1c2046c7f03 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -0,0 +1,313 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""The PoissonLogNormalQuadratureCompound distribution class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops.distributions import categorical as categorical_lib
+from tensorflow.python.ops.distributions import distribution as distribution_lib
+from tensorflow.python.ops.distributions import util as distribution_util
+
+
+__all__ = [
+    "PoissonLogNormalQuadratureCompound",
+]
+
+
+class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
+  """`PoissonLogNormalQuadratureCompound` distribution.
+
+  The `PoissonLogNormalQuadratureCompound` is an approximation to a
+  Poisson-LogNormal [compound distribution](
+  https://en.wikipedia.org/wiki/Compound_probability_distribution), i.e.,
+
+  ```none
+  p(k|loc, scale)
+  = int_{R_+} dl LogNormal(l | loc, scale) Poisson(k | l)
+  = int_{R} dz ((lambda(z) sqrt(2) scale)
+                * exp(-z**2) / (lambda(z) sqrt(2 pi) scale)
+                * Poisson(k | lambda(z)))
+  = int_{R} dz exp(-z**2) / sqrt(pi) Poisson(k | lambda(z))
+  approx= sum{ prob[d] Poisson(k | lambda(grid[d])) : d=0, ..., deg-1 }
+  ```
+
+  where `lambda(z) = exp(sqrt(2) scale z + loc)` and the `prob,grid` terms
+  are from [Gauss--Hermite quadrature](
+  https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature). Note that
+  the second line made the substitution
+  `z(l) = (log(l) - loc) / (sqrt(2) scale)`, which implies `lambda(z)` [above]
+  and `dl = sqrt(2) scale lambda(z) dz`.
+
+  In the non-approximation case, a draw from the LogNormal prior represents the
+  Poisson rate parameter. Unfortunately, the non-approximate distribution lacks
+  an analytical probability density function (pdf). Therefore the
+  `PoissonLogNormalQuadratureCompound` class implements an approximation based
+  on [Gauss-Hermite quadrature](
+  https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature).
+  Note: although the `PoissonLogNormalQuadratureCompound` is approximately the
+  Poisson-LogNormal compound distribution, it is itself a valid distribution.
+  Viz., it possesses a `sample`, `log_prob`, `mean`, `variance`, etc., which
+  are all mutually consistent.
+
+  #### Mathematical Details
+
+  The `PoissonLogNormalQuadratureCompound` approximates a Poisson-LogNormal
+  [compound distribution](
+  https://en.wikipedia.org/wiki/Compound_probability_distribution).
+  Using variable-substitution and [Gauss-Hermite quadrature](
+  https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature) we can
+  redefine the distribution to be a parameter-less convex combination of `deg`
+  different Poisson samples.
+
+  That is, defined over non-negative integers, this distribution is
+  parameterized by a (batch of) `loc` and `scale` scalars.
+
+  The probability density function (pdf) is,
+
+  ```none
+  pdf(k | loc, scale, deg)
+  = sum{ prob[d] Poisson(k | lambda=exp(sqrt(2) scale grid[d] + loc))
+         : d=0, ..., deg-1 }
+  ```
+
+  where [`grid, w = numpy.polynomial.hermite.hermgauss(deg)`](
+  https://docs.scipy.org/doc/numpy-1.10.0/reference/generated/numpy.polynomial.hermite.hermgauss.html)
+  and `prob = w / sqrt(pi)`.
+
+  #### Examples
+
+  ```python
+  ds = tf.contrib.distributions
+  # Create two batches of PoissonLogNormalQuadratureCompounds, one with
+  # prior `loc = 0.` and another with `loc = -0.5`. In both cases `scale = 1.`
+  pln = ds.PoissonLogNormalQuadratureCompound(
+      loc=[0., -0.5],
+      scale=1.,
+      quadrature_polynomial_degree=10,
+      validate_args=True)
+  ```
+  """
+
+  def __init__(self,
+               loc,
+               scale,
+               quadrature_polynomial_degree=8,
+               validate_args=False,
+               allow_nan_stats=True,
+               name="PoissonLogNormalQuadratureCompound"):
+    """Constructs the PoissonLogNormalQuadratureCompound.
+
+    Args:
+      loc: `float`-like (batch of) scalar `Tensor`; the location parameter of
+        the LogNormal prior.
+      scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
+        the LogNormal prior.
+      quadrature_polynomial_degree: Python `int`-like scalar.
+        Default value: 8.
+      validate_args: Python `bool`, default `False`. When `True` distribution
+        parameters are checked for validity despite possibly degrading runtime
+        performance. When `False` invalid inputs may silently render incorrect
+        outputs.
+      allow_nan_stats: Python `bool`, default `True`. When `True`,
+        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
+        indicate the result is undefined. When `False`, an exception is raised
+        if one or more of the statistic's batch members are undefined.
+      name: Python `str` name prefixed to Ops created by this class.
+
+    Raises:
+      TypeError: if `loc.dtype != scale.dtype`.
+    """
+    parameters = locals()
+    with ops.name_scope(name, values=[loc, scale]):
+      loc = ops.convert_to_tensor(loc, name="loc")
+      self._loc = loc
+
+      scale = ops.convert_to_tensor(scale, name="scale")
+      self._scale = scale
+
+      dtype = loc.dtype.base_dtype
+      if dtype != scale.dtype.base_dtype:
+        raise TypeError(
+            "loc.dtype(\"{}\") does not match scale.dtype(\"{}\")".format(
+                loc.dtype.name, scale.dtype.name))
+
+      self._degree = quadrature_polynomial_degree
+
+      grid, prob = np.polynomial.hermite.hermgauss(
+          deg=quadrature_polynomial_degree)
+
+      # It should be that `sum(prob) == sqrt(pi)`, but self-normalization is
+      # more numerically stable.
+      prob = prob.astype(dtype.as_numpy_dtype)
+      prob /= np.linalg.norm(prob, ord=1)
+
+      self._mixture_distribution = categorical_lib.Categorical(
+          logits=np.log(prob),
+          validate_args=validate_args,
+          allow_nan_stats=allow_nan_stats)
+
+      # The following maps the broadcast of `loc` and `scale` to each grid
+      # point, i.e., we are creating several log-rates that correspond to the
+      # different Gauss-Hermite quadrature points and (possible) batches of
+      # `loc` and `scale`.
+      self._log_rate = (loc[..., array_ops.newaxis] +
+                        np.sqrt(2.) * scale[..., array_ops.newaxis] * grid)
+
+      self._distribution = poisson_lib.Poisson(
+          rate=math_ops.exp(self._log_rate, name="rate"),
+          validate_args=validate_args,
+          allow_nan_stats=allow_nan_stats)
+
+    super(PoissonLogNormalQuadratureCompound, self).__init__(
+        dtype=dtype,
+        reparameterization_type=distribution_lib.NOT_REPARAMETERIZED,
+        validate_args=validate_args,
+        allow_nan_stats=allow_nan_stats,
+        parameters=parameters,
+        graph_parents=[loc, scale],
+        name=name)
+
+  @property
+  def mixture_distribution(self):
+    """Distribution which randomly selects a Poisson with Gauss-Hermite rate."""
+    return self._mixture_distribution
+
+  @property
+  def distribution(self):
+    """Base Poisson parameterized by a Gauss-Hermite grid of rates."""
+    return self._distribution
+
+  @property
+  def loc(self):
+    """Location parameter of the LogNormal prior."""
+    return self._loc
+
+  @property
+  def scale(self):
+    """Scale parameter of the LogNormal prior."""
+    return self._scale
+
+  @property
+  def quadrature_polynomial_degree(self):
+    """Degree of the Hermite polynomial used for Gauss-Hermite quadrature."""
+    return self._degree
+
+  def _batch_shape_tensor(self):
+    return array_ops.broadcast_dynamic_shape(
+        array_ops.shape(self.loc),
+        array_ops.shape(self.scale))
+
+  def _batch_shape(self):
+    return array_ops.broadcast_static_shape(
+        self.loc.shape,
+        self.scale.shape)
+
+  def _event_shape(self):
+    return tensor_shape.scalar()
+
+  def _sample_n(self, n, seed=None):
+    # Get ids as an [n, batch_size]-shaped matrix, unless batch_shape=[], in
+    # which case ids is an [n]-shaped vector.
+    batch_size = (np.prod(self.batch_shape.as_list(), dtype=np.int32)
+                  if self.batch_shape.is_fully_defined()
+                  else math_ops.reduce_prod(self.batch_shape_tensor()))
+    ids = self._mixture_distribution.sample(
+        sample_shape=concat_vectors(
+            [n],
+            distribution_util.pick_vector(
+                self.is_scalar_batch(),
+                np.int32([]),
+                [batch_size])),
+        seed=distribution_util.gen_new_seed(
+            seed, "poisson_lognormal_quadrature_compound"))
+    # Offset ids by a stride of `quadrature_polynomial_degree` per batch
+    # member, so they index into the flattened [batch_size, deg] rate matrix.
+    offset = math_ops.range(start=0,
+                            limit=batch_size * self._degree,
+                            delta=self._degree,
+                            dtype=ids.dtype)
+    ids += offset
+    rate = array_ops.gather(
+        array_ops.reshape(self.distribution.rate, shape=[-1]), ids)
+    rate = array_ops.reshape(
+        rate, shape=concat_vectors([n], self.batch_shape_tensor()))
+    return random_ops.random_poisson(
+        lam=rate, shape=[], dtype=self.dtype, seed=seed)
+
+  def _log_prob(self, x):
+    return math_ops.reduce_logsumexp(
+        (self.mixture_distribution.logits
+         + self.distribution.log_prob(x[..., array_ops.newaxis])),
+        axis=-1)
+
+  def _mean(self):
+    return math_ops.exp(
+        math_ops.reduce_logsumexp(
+            self.mixture_distribution.logits + self._log_rate,
+            axis=-1))
+
+  def _variance(self):
+    return math_ops.exp(self._log_variance())
+
+  def _stddev(self):
+    return math_ops.exp(0.5 * self._log_variance())
+
+  def _log_variance(self):
+    # The following calculation is based on the law of total variance:
+    #
+    # Var[Z] = E[Var[Z | V]] + Var[E[Z | V]]
+    #
+    # where,
+    #
+    # Z|v ~ Poisson(rate[v])
+    # V ~ mixture_distribution
+    #
+    # thus,
+    #
+    # E[Var[Z | V]] = sum{ prob[d] Var[d] : d=0, ..., deg-1 }
+    # Var[E[Z | V]] = sum{ prob[d] (Mean[d] - Mean)**2 : d=0, ..., deg-1 }
+    v = array_ops.stack([
+        # log(self.distribution.variance()) = log(Var[d]) = log(rate[d])
+        self._log_rate,
+        # log((Mean[d] - Mean)**2)
+        2. * math_ops.log(
+            math_ops.abs(self.distribution.mean()
+                         - self._mean()[..., array_ops.newaxis])),
+    ], axis=-1)
+    return math_ops.reduce_logsumexp(
+        self.mixture_distribution.logits[..., array_ops.newaxis] + v,
+        axis=[-2, -1])
+
+
+def static_value(x):
+  """Returns the static value of a `Tensor` or `None`."""
+  return tensor_util.constant_value(ops.convert_to_tensor(x))
+
+
+def concat_vectors(*args):
+  """Concatenates input vectors, statically if possible."""
+  args_ = [static_value(x) for x in args]
+  if any(vec is None for vec in args_):
+    return array_ops.concat(args, axis=0)
+  return [val for vec in args_ for val in vec]
diff --git a/tensorflow/contrib/distributions/python/ops/test_util.py b/tensorflow/contrib/distributions/python/ops/test_util.py
new file mode 100644
index 00000000000..da7d3907acb
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/test_util.py
@@ -0,0 +1,378 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for testing distributions and/or bijectors."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import histogram_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+
+
+__all__ = [
+    "DiscreteScalarDistributionTestHelpers",
+    "VectorDistributionTestHelpers",
+]
+
+
+class DiscreteScalarDistributionTestHelpers(object):
+  """Helpers for testing discrete, scalar-event distributions."""
+
+  def run_test_sample_consistent_log_prob(
+      self, sess, dist,
+      num_samples=int(1e5), num_threshold=int(1e3), seed=42,
+      rtol=1e-2, atol=0.):
+    """Tests that sample/log_prob are consistent with each other.
+
+    "Consistency" means that `sample` and `log_prob` correspond to the same
+    distribution.
+
+    Note: this test only verifies a necessary condition for consistency--it
+    does not verify sufficiency, hence it does not prove that `sample` and
+    `log_prob` truly are consistent.
+
+    Args:
+      sess: Tensorflow session.
+      dist: Distribution instance or object which implements `sample`,
+        `log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
+      num_samples: Python `int` scalar indicating the number of Monte-Carlo
+        samples to draw from `dist`.
+      num_threshold: Python `int` scalar indicating the number of samples a
+        bucket must contain before being compared to the probability.
+        Default value: 1e3; must be at least 1.
+        Warning: setting this too high causes the test to falsely pass, while
+        setting it too low causes the test to falsely fail.
+      seed: Python `int` indicating the seed to use when sampling from `dist`.
+        In general it is not recommended to use `None` during a test as this
+        increases the likelihood of spurious test failure.
+      rtol: Python `float`-type indicating the admissible relative error
+        between analytical and sample statistics.
+      atol: Python `float`-type indicating the admissible absolute error
+        between analytical and sample statistics.
+
+    Raises:
+      ValueError: if `num_threshold < 1`.
+    """
+    if num_threshold < 1:
+      raise ValueError("num_threshold({}) must be at least 1.".format(
+          num_threshold))
+    # Histogram only supports vectors so we call it once per batch coordinate.
+    y = dist.sample(num_samples, seed=seed)
+    y = array_ops.reshape(y, shape=[num_samples, -1])
+    batch_size = math_ops.reduce_prod(dist.batch_shape_tensor())
+    batch_dims = array_ops.shape(dist.batch_shape_tensor())[0]
+    edges_expanded_shape = 1 + array_ops.pad([-2], paddings=[[0, batch_dims]])
+    for b, x in enumerate(array_ops.unstack(y, axis=1)):
+      counts, edges = self.histogram(x)
+      edges = array_ops.reshape(edges, edges_expanded_shape)
+      probs = math_ops.exp(dist.log_prob(edges))
+      probs = array_ops.reshape(probs, shape=[-1, batch_size])[:, b]
+
+      [counts_, probs_] = sess.run([counts, probs])
+      valid = counts_ > num_threshold
+      probs_ = probs_[valid]
+      counts_ = counts_[valid]
+      self.assertAllClose(probs_, counts_ / num_samples,
+                          rtol=rtol, atol=atol)
+
+  def run_test_sample_consistent_mean_variance(
+      self, sess, dist,
+      num_samples=int(1e5), seed=24,
+      rtol=1e-2, atol=0.):
+    """Tests that sample/mean/variance are consistent with each other.
+
+    "Consistency" means that `sample`, `mean`, `variance`, etc. all correspond
+    to the same distribution.
+
+    Args:
+      sess: Tensorflow session.
+      dist: Distribution instance or object which implements `sample`,
+        `log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
+      num_samples: Python `int` scalar indicating the number of Monte-Carlo
+        samples to draw from `dist`.
+      seed: Python `int` indicating the seed to use when sampling from `dist`.
+        In general it is not recommended to use `None` during a test as this
+        increases the likelihood of spurious test failure.
+      rtol: Python `float`-type indicating the admissible relative error
+        between analytical and sample statistics.
+      atol: Python `float`-type indicating the admissible absolute error
+        between analytical and sample statistics.
+    """
+    x = math_ops.to_float(dist.sample(num_samples, seed=seed))
+    sample_mean = math_ops.reduce_mean(x, axis=0)
+    sample_variance = math_ops.reduce_mean(
+        math_ops.square(x - sample_mean), axis=0)
+    sample_stddev = math_ops.sqrt(sample_variance)
+
+    [
+        sample_mean_,
+        sample_variance_,
+        sample_stddev_,
+        mean_,
+        variance_,
+        stddev_
+    ] = sess.run([
+        sample_mean,
+        sample_variance,
+        sample_stddev,
+        dist.mean(),
+        dist.variance(),
+        dist.stddev(),
+    ])
+
+    self.assertAllClose(mean_, sample_mean_, rtol=rtol, atol=atol)
+    self.assertAllClose(variance_, sample_variance_, rtol=rtol, atol=atol)
+    self.assertAllClose(stddev_, sample_stddev_, rtol=rtol, atol=atol)
+
+  def histogram(self, x, value_range=None, nbins=None, name=None):
+    """Return histogram of values.
+
+    Given the tensor `x`, this operation returns a rank-1 histogram
+    counting the number of entries in `x` that fell into every bin. The
+    bins are equal width and determined by the arguments `value_range` and
+    `nbins`.
+
+    Args:
+      x: 1D numeric `Tensor` of items to count.
+      value_range: Shape [2] `Tensor`. `x <= value_range[0]` will be
+        mapped to `hist[0]`, `x >= value_range[1]` will be mapped to
+        `hist[-1]`. Must be same dtype as `x`.
+      nbins: Scalar `int32 Tensor`. Number of histogram bins.
+      name: Python `str` name prefixed to Ops created by this method.
+
+    Returns:
+      counts: 1D `Tensor` of counts, i.e.,
+        `counts[i] = sum{ edges[i] <= x[j] < edges[i] + delta : j }`.
+      edges: 1D `Tensor` characterizing intervals used for counting.
+    """
+    with ops.name_scope(name, "histogram", [x]):
+      x = ops.convert_to_tensor(x, name="x")
+      if value_range is None:
+        value_range = [math_ops.reduce_min(x), 1 + math_ops.reduce_max(x)]
+      value_range = ops.convert_to_tensor(value_range, name="value_range")
+      lo = value_range[0]
+      hi = value_range[1]
+      if nbins is None:
+        nbins = math_ops.to_int32(hi - lo)
+      delta = (hi - lo) / math_ops.cast(
+          nbins, dtype=value_range.dtype.base_dtype)
+      edges = math_ops.range(
+          start=lo, limit=hi, delta=delta, dtype=x.dtype.base_dtype)
+      counts = histogram_ops.histogram_fixed_width(
+          x, value_range=value_range, nbins=nbins)
+      return counts, edges
+
+
+class VectorDistributionTestHelpers(object):
+  """VectorDistributionTestHelpers helps test vector-event distributions."""
+
+  def run_test_sample_consistent_log_prob(
+      self,
+      sess,
+      dist,
+      num_samples=int(1e5),
+      radius=1.,
+      center=0.,
+      seed=42,
+      rtol=1e-2,
+      atol=0.):
+    """Tests that sample/log_prob are mutually consistent.
+
+    "Consistency" means that `sample` and `log_prob` correspond to the same
+    distribution.
+
+    The idea of this test is to compute the Monte-Carlo estimate of the volume
+    enclosed by a hypersphere, i.e., the volume of an `n`-ball. While we could
+    choose an arbitrary function to integrate, the hypersphere's volume is nice
+    because it is intuitive, has an easy analytical expression, and works for
+    `dimensions > 1`.
+
+    Technical Details:
+
+    Observe that:
+
+    ```none
+    int_{R**d} dx [x in Ball(radius=r, center=c)]
+    = E_{p(X)}[ [X in Ball(r, c)] / p(X) ]
+    = lim_{m->infty} m**-1 sum_j^m [x[j] in Ball(r, c)] / p(x[j]),
+        where x[j] ~iid p(X)
+    ```
+
+    Thus, for fixed `m`, the above is approximately true when `sample` and
+    `log_prob` are mutually consistent.
+
+    Furthermore, the above calculation has the analytical result:
+    `pi**(d/2) r**d / Gamma(1 + d/2)`.
+
+    Note: this test only verifies a necessary condition for consistency--it
+    does not verify sufficiency, hence it does not prove that `sample` and
+    `log_prob` truly are consistent. For this reason we recommend testing
+    several different hyperspheres (assuming the hypersphere is supported by
+    the distribution). Furthermore, we gain additional trust in this test
+    when `sample` is also tested against the first and second moments
+    (`run_test_sample_consistent_mean_covariance`); it is unlikely that a
+    "best-effort" implementation of `log_prob` would incorrectly pass both
+    tests for different hyperspheres.
+
+    For a discussion of the analytical result (second line) see:
+    https://en.wikipedia.org/wiki/Volume_of_an_n-ball.
+
+    For a discussion of importance sampling (fourth line) see:
+    https://en.wikipedia.org/wiki/Importance_sampling.
+
+    Args:
+      sess: Tensorflow session.
+      dist: Distribution instance or object which implements `sample`,
+        `log_prob`, `event_shape_tensor` and `batch_shape_tensor`. The
+        distribution must have non-zero probability of sampling every point
+        enclosed by the hypersphere.
+      num_samples: Python `int` scalar indicating the number of Monte-Carlo
+        samples to draw from `dist`.
+      radius: Python `float`-type indicating the radius of the `n`-ball whose
+        volume we're computing.
+      center: Python floating-type vector (or scalar) indicating the center of
+        the `n`-ball whose volume we're computing. When scalar, the value is
+        broadcast to all event dims.
+      seed: Python `int` indicating the seed to use when sampling from `dist`.
+        In general it is not recommended to use `None` during a test as this
+        increases the likelihood of spurious test failure.
+      rtol: Python `float`-type indicating the admissible relative error
+        between actual and approximate volumes.
+      atol: Python `float`-type indicating the admissible absolute error
+        between actual and approximate volumes. In general this should be zero
+        since a typical radius implies a non-zero volume.
+    """
+
+    def actual_hypersphere_volume(dims, radius):
+      # https://en.wikipedia.org/wiki/Volume_of_an_n-ball
+      # Using tf.lgamma because otherwise we'd have to use SciPy, which is not
+      # a required dependency of core.
+      radius = np.asarray(radius)
+      dims = math_ops.cast(dims, dtype=radius.dtype)
+      return math_ops.exp(
+          (dims / 2.) * np.log(np.pi)
+          - math_ops.lgamma(1. + dims / 2.)
+          + dims * math_ops.log(radius))
+
+    def is_in_ball(x, radius, center):
+      return math_ops.cast(linalg_ops.norm(x - center, axis=-1) <= radius,
+                           dtype=x.dtype)
+
+    def monte_carlo_hypersphere_volume(dist, num_samples, radius, center):
+      # https://en.wikipedia.org/wiki/Importance_sampling
+      x = dist.sample(num_samples, seed=seed)
+      return math_ops.reduce_mean(
+          math_ops.exp(-dist.log_prob(x)) * is_in_ball(x, radius, center),
+          axis=0)
+
+    [
+        batch_shape_,
+        actual_volume_,
+        sample_volume_,
+    ] = sess.run([
+        dist.batch_shape_tensor(),
+        actual_hypersphere_volume(
+            dims=dist.event_shape_tensor()[0],
+            radius=radius),
+        monte_carlo_hypersphere_volume(
+            dist,
+            num_samples=num_samples,
+            radius=radius,
+            center=center),
+    ])
+
+    self.assertAllClose(np.tile(actual_volume_, reps=batch_shape_),
+                        sample_volume_,
+                        rtol=rtol, atol=atol)
+
+  def run_test_sample_consistent_mean_covariance(
+      self,
+      sess,
+      dist,
+      num_samples=int(1e5),
+      seed=24,
+      rtol=1e-2,
+      atol=0.,
+      cov_rtol=None,
+      cov_atol=None):
+    """Tests that sample/mean/covariance are consistent with each other.
+
+    "Consistency" means that `sample`, `mean`, `covariance`, etc. all
+    correspond to the same distribution.
+
+    Args:
+      sess: Tensorflow session.
+      dist: Distribution instance or object which implements `sample`,
+        `log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
+      num_samples: Python `int` scalar indicating the number of Monte-Carlo
+        samples to draw from `dist`.
+      seed: Python `int` indicating the seed to use when sampling from `dist`.
+        In general it is not recommended to use `None` during a test as this
+        increases the likelihood of spurious test failure.
+      rtol: Python `float`-type indicating the admissible relative error
+        between analytical and sample statistics.
+      atol: Python `float`-type indicating the admissible absolute error
+        between analytical and sample statistics.
+      cov_rtol: Python `float`-type indicating the admissible relative error
+        between analytical and sample covariance. Default: rtol.
+      cov_atol: Python `float`-type indicating the admissible absolute error
+        between analytical and sample covariance. Default: atol.
+    """
+
+    x = dist.sample(num_samples, seed=seed)
+    sample_mean = math_ops.reduce_mean(x, axis=0)
+    sample_covariance = math_ops.reduce_mean(
+        _vec_outer_square(x - sample_mean), axis=0)
+    sample_variance = array_ops.matrix_diag_part(sample_covariance)
+    sample_stddev = math_ops.sqrt(sample_variance)
+
+    [
+        sample_mean_,
+        sample_covariance_,
+        sample_variance_,
+        sample_stddev_,
+        mean_,
+        covariance_,
+        variance_,
+        stddev_
+    ] = sess.run([
+        sample_mean,
+        sample_covariance,
+        sample_variance,
+        sample_stddev,
+        dist.mean(),
+        dist.covariance(),
+        dist.variance(),
+        dist.stddev(),
+    ])
+
+    self.assertAllClose(mean_, sample_mean_, rtol=rtol, atol=atol)
+    self.assertAllClose(covariance_, sample_covariance_,
+                        rtol=cov_rtol or rtol,
+                        atol=cov_atol or atol)
+    self.assertAllClose(variance_, sample_variance_, rtol=rtol, atol=atol)
+    self.assertAllClose(stddev_, sample_stddev_, rtol=rtol, atol=atol)
+
+
+def _vec_outer_square(x, name=None):
+  """Computes the outer product of a vector with itself, i.e., x x^T."""
+  with ops.name_scope(name, "vec_osquare", [x]):
+    return x[..., :, array_ops.newaxis] * x[..., array_ops.newaxis, :]
diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
index 282683ef39d..0a9662ed754 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
@@ -221,7 +221,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
                quadrature_polynomial_degree=8,
                validate_args=False,
                allow_nan_stats=True,
-               name="VectorLocationScaleDiffeomixture"):
+               name="VectorDiffeomixture"):
     """Constructs the VectorDiffeomixture on `R**k`.
 
     Args:
@@ -387,24 +387,34 @@ class VectorDiffeomixture(distribution_lib.Distribution):
 
   @property
   def mixture_distribution(self):
+    """Distribution used to select a convex combination of affine transforms."""
     return self._mixture_distribution
 
   @property
   def distribution(self):
+    """Base scalar-event, scalar-batch distribution."""
     return self._distribution
 
   @property
   def interpolate_weight(self):
+    """Grid of mixing probabilities, one for each grid point."""
     return self._interpolate_weight
 
   @property
   def endpoint_affine(self):
+    """Affine transformation for each of `K` components."""
     return self._endpoint_affine
 
   @property
   def interpolated_affine(self):
+    """Affine transformation for each convex combination of `K` components."""
    return self._interpolated_affine
 
+  @property
+  def quadrature_polynomial_degree(self):
+    """Degree of the Hermite polynomial used for Gauss-Hermite quadrature."""
+    return self._degree
+
   def _batch_shape_tensor(self):
     return self._batch_shape_
 
@@ -457,12 +467,12 @@ class VectorDiffeomixture(distribution_lib.Distribution):
 
     # Alternatively:
     # x = weight * x[0] + (1. - weight) * x[1]
-    x = weight * (x[0] - x[1]) + array_ops.ones_like(x[0]) * x[1]
+    x = weight * (x[0] - x[1]) + x[1]
 
     return x
 
   def _log_prob(self, x):
-    # By convention, we always put the the grid points right-most.
+    # By convention, we always put the grid points right-most.
y = array_ops.stack( [aff.inverse(x) for aff in self.interpolated_affine], axis=-1) @@ -740,8 +750,7 @@ def interpolate_loc(deg, interpolate_weight, loc): x = interpolate_weight[..., array_ops.newaxis] * loc[0] else: delta = loc[0] - loc[1] - offset = array_ops.ones_like(loc[0]) * loc[1] - x = interpolate_weight[..., array_ops.newaxis] * delta + offset + x = interpolate_weight[..., array_ops.newaxis] * delta + loc[1] return [x[..., k, :] for k in range(deg)]
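For reference, a usage sketch of the new distribution, mirroring the class
docstring and the unit tests above (not part of the patch itself; TF 1.x
session-style API, with `ds` aliasing `tf.contrib.distributions` as in the
docstring):

```python
import tensorflow as tf

ds = tf.contrib.distributions

# Batch of two compounds: LogNormal priors with loc 0. and -0.5; both use
# scale 1. and a degree-10 Gauss-Hermite quadrature.
pln = ds.PoissonLogNormalQuadratureCompound(
    loc=[0., -0.5],
    scale=1.,
    quadrature_polynomial_degree=10,
    validate_args=True)

counts = pln.sample(int(1e4), seed=42)  # shape: [10000, 2]
log_pmf = pln.log_prob(counts)          # quadrature-mixture log-prob
with tf.Session() as sess:
  counts_, mean_, variance_ = sess.run([counts, pln.mean(), pln.variance()])
```

The `[10000, 2]` sample shape comes from broadcasting the batch of two `loc`
values against the scalar `scale`, exactly as exercised by
`testSampleProbConsistentBroadcastScalar`.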
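The quadrature approximation itself can be sanity-checked outside TensorFlow.
Below is a minimal NumPy-only sketch (also not part of the patch; `approx_pmf`
and `brute_pmf` are hypothetical names) comparing the Gauss-Hermite pmf from
the class docstring against brute-force numerical integration of the exact
compound density; agreement improves with `deg`, and larger `scale` needs a
higher degree:

```python
import math

import numpy as np


def approx_pmf(k, loc, scale, deg=10):
  """Quadrature pmf: sum_d prob[d] Poisson(k | exp(sqrt(2) scale grid[d] + loc))."""
  grid, w = np.polynomial.hermite.hermgauss(deg)
  prob = w / np.sqrt(np.pi)  # sum(w) == sqrt(pi), so `prob` sums to 1.
  rate = np.exp(np.sqrt(2.) * scale * grid + loc)  # lambda(grid)
  log_poisson = k * np.log(rate) - rate - math.lgamma(k + 1.)
  return np.sum(prob * np.exp(log_poisson))


def brute_pmf(k, loc, scale, num_grid=20001):
  """Trapezoidal integration of the exact compound pmf, for comparison."""
  z = np.linspace(-10., 10., num_grid)  # standard-normal substitution
  rate = np.exp(scale * z + loc)
  log_poisson = k * np.log(rate) - rate - math.lgamma(k + 1.)
  std_normal = np.exp(-0.5 * z**2) / np.sqrt(2. * np.pi)
  return np.trapz(std_normal * np.exp(log_poisson), z)


for k in range(4):
  print(k, approx_pmf(k, loc=0., scale=0.5), brute_pmf(k, loc=0., scale=0.5))
```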