diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 3fd428e1220..2d5a708bac6 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -99,7 +99,16 @@ cuda_py_tests( srcs = ["python/kernel_tests/beta_test.py"], additional_deps = [ ":distributions_py", - "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_tests( + name = "binomial_test", + size = "small", + srcs = ["python/kernel_tests/binomial_test.py"], + additional_deps = [ + ":distributions_py", "//tensorflow/python:platform_test", ], tags = ["notsan"], @@ -179,9 +188,8 @@ cuda_py_tests( ) cuda_py_tests( - name = "kullback_leibler_test", - size = "small", - srcs = ["python/kernel_tests/kullback_leibler_test.py"], + name = "laplace_test", + srcs = ["python/kernel_tests/laplace_test.py"], additional_deps = [ ":distributions_py", "//tensorflow/python:framework_test_lib", @@ -190,13 +198,14 @@ cuda_py_tests( ) cuda_py_tests( - name = "laplace_test", - srcs = ["python/kernel_tests/laplace_test.py"], + name = "multinomial_test", + srcs = ["python/kernel_tests/multinomial_test.py"], additional_deps = [ ":distributions_py", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], + tags = ["notsan"], ) cuda_py_tests( @@ -239,6 +248,15 @@ cuda_py_tests( srcs = ["python/kernel_tests/uniform_test.py"], additional_deps = [ ":distributions_py", + "//tensorflow/python:framework_test_lib", + ], +) + +cuda_py_tests( + name = "kullback_leibler_test", + size = "small", + srcs = ["python/kernel_tests/kullback_leibler_test.py"], + additional_deps = [ "//tensorflow/python:platform_test", ], ) diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 2b32556f3eb..83719157761 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -25,6 +25,7 @@ initialized with parameters that define the distributions. ### Univariate (scalar) distributions +@@Binomial @@Bernoulli @@Beta @@Categorical @@ -50,6 +51,7 @@ initialized with parameters that define the distributions. 
@@Dirichlet @@DirichletMultinomial +@@Multinomial ### Transformed distributions @@ -79,6 +81,7 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops.bernoulli import * from tensorflow.contrib.distributions.python.ops.beta import * +from tensorflow.contrib.distributions.python.ops.binomial import * from tensorflow.contrib.distributions.python.ops.categorical import * from tensorflow.contrib.distributions.python.ops.chi2 import * from tensorflow.contrib.distributions.python.ops.dirichlet import * @@ -89,6 +92,7 @@ from tensorflow.contrib.distributions.python.ops.gamma import * from tensorflow.contrib.distributions.python.ops.inverse_gamma import * from tensorflow.contrib.distributions.python.ops.kullback_leibler import * from tensorflow.contrib.distributions.python.ops.laplace import * +from tensorflow.contrib.distributions.python.ops.multinomial import * from tensorflow.contrib.distributions.python.ops.mvn import * from tensorflow.contrib.distributions.python.ops.normal import * from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import * diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py index c636a4d060c..82f77fbfd1e 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py @@ -57,10 +57,17 @@ class BernoulliTest(tf.test.TestCase): self.assertAllClose(scipy.special.logit(p), dist.logits.eval()) def testInvalidP(self): - invalid_ps = [1.01, -0.01, 2., -3.] + invalid_ps = [1.01, 2.] for p in invalid_ps: with self.test_session(): - with self.assertRaisesOpError("x <= y"): + with self.assertRaisesOpError("p has components greater than 1"): + dist = tf.contrib.distributions.Bernoulli(p=p) + dist.p.eval() + + invalid_ps = [-0.01, -3.] + for p in invalid_ps: + with self.test_session(): + with self.assertRaisesOpError("Condition x >= 0"): dist = tf.contrib.distributions.Bernoulli(p=p) dist.p.eval() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py new file mode 100644 index 00000000000..8b2520f8368 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py @@ -0,0 +1,173 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import stats +import tensorflow as tf + + +class BinomialTest(tf.test.TestCase): + + def testSimpleShapes(self): + with self.test_session(): + p = np.float32(np.random.beta(1, 1)) + binom = tf.contrib.distributions.Binomial(n=1., p=p) + self.assertAllEqual([], binom.event_shape().eval()) + self.assertAllEqual([], binom.batch_shape().eval()) + self.assertEqual(tf.TensorShape([]), binom.get_event_shape()) + self.assertEqual(tf.TensorShape([]), binom.get_batch_shape()) + + def testComplexShapes(self): + with self.test_session(): + p = np.random.beta(1, 1, size=(3, 2)).astype(np.float32) + n = [[3., 2], [4, 5], [6, 7]] + binom = tf.contrib.distributions.Binomial(n=n, p=p) + self.assertAllEqual([], binom.event_shape().eval()) + self.assertAllEqual([3, 2], binom.batch_shape().eval()) + self.assertEqual(tf.TensorShape([]), binom.get_event_shape()) + self.assertEqual(tf.TensorShape([3, 2]), binom.get_batch_shape()) + + def testNProperty(self): + p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]] + n = [[3.], [4]] + with self.test_session(): + binom = tf.contrib.distributions.Binomial(n=n, p=p) + self.assertEqual((2, 1), binom.n.get_shape()) + self.assertAllClose(n, binom.n.eval()) + + def testPProperty(self): + p = [[0.1, 0.2, 0.7]] + with self.test_session(): + binom = tf.contrib.distributions.Binomial(n=3., p=p) + self.assertEqual((1, 3), binom.p.get_shape()) + self.assertEqual((1, 3), binom.logits.get_shape()) + self.assertAllClose(p, binom.p.eval()) + + def testLogitsProperty(self): + logits = [[0., 9., -0.5]] + with self.test_session(): + binom = tf.contrib.distributions.Binomial(n=3., logits=logits) + self.assertEqual((1, 3), binom.p.get_shape()) + self.assertEqual((1, 3), binom.logits.get_shape()) + self.assertAllClose(logits, binom.logits.eval()) + + def testPmfNandCountsAgree(self): + p = [[0.1, 0.2, 0.7]] + n = [[5.]] + with self.test_session(): + binom = tf.contrib.distributions.Binomial(n=n, p=p) + binom.pmf([2., 3, 2]).eval() + binom.pmf([3., 1, 2]).eval() + with self.assertRaisesOpError('Condition x >= 0.*'): + binom.pmf([-1., 4, 2]).eval() + with self.assertRaisesOpError('Condition x <= y.*'): + binom.pmf([7., 3, 0]).eval() + + def testPmf_non_integer_counts(self): + p = [[0.1, 0.2, 0.7]] + n = [[5.]] + with self.test_session(): + # No errors with integer n. + binom = tf.contrib.distributions.Binomial(n=n, p=p) + binom.pmf([2., 3, 2]).eval() + binom.pmf([3., 1, 2]).eval() + # Both equality and integer checking fail. + with self.assertRaisesOpError('Condition x == y.*'): + binom.pmf([1.0, 2.5, 1.5]).eval() + + binom = tf.contrib.distributions.Binomial(n=n, p=p, validate_args=False) + binom.pmf([1., 2., 3.]).eval() + # Non-integer arguments work. + binom.pmf([1.0, 2.5, 1.5]).eval() + + def testPmfBothZeroBatches(self): + with self.test_session(): + # Both zero-batches. No broadcast + p = 0.5 + counts = 1. + pmf = tf.contrib.distributions.Binomial(n=1., p=p).pmf(counts) + self.assertAllClose(0.5, pmf.eval()) + self.assertEqual((), pmf.get_shape()) + + def testPmfBothZeroBatchesNontrivialN(self): + with self.test_session(): + # Both zero-batches. No broadcast + p = 0.1 + counts = 3. 
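+      # pmf(3; n=5, p=0.1) = C(5, 3) * 0.1**3 * 0.9**2
+      #                    = 10 * 0.001 * 0.81 = 0.0081,
+      # which is what stats.binom.pmf(counts, n=5., p=p) returns below.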
+ binom = tf.contrib.distributions.Binomial(n=5., p=p) + pmf = binom.pmf(counts) + self.assertAllClose(stats.binom.pmf(counts, n=5., p=p), pmf.eval()) + self.assertEqual((), pmf.get_shape()) + + def testPmfPStretchedInBroadcastWhenSameRank(self): + with self.test_session(): + p = [[0.1, 0.9]] + counts = [[1., 2.]] + pmf = tf.contrib.distributions.Binomial(n=3., p=p).pmf(counts) + self.assertAllClose(stats.binom.pmf(counts, n=3., p=p), pmf.eval()) + self.assertEqual((1, 2), pmf.get_shape()) + + def testPmfPStretchedInBroadcastWhenLowerRank(self): + with self.test_session(): + p = [0.1, 0.4] + counts = [[1.], [0.]] + pmf = tf.contrib.distributions.Binomial(n=1., p=p).pmf(counts) + self.assertAllClose([[0.1, 0.4], [0.9, 0.6]], pmf.eval()) + self.assertEqual((2, 2), pmf.get_shape()) + + def testBinomialMean(self): + with self.test_session(): + n = 5. + p = [0.1, 0.2, 0.7] + binom = tf.contrib.distributions.Binomial(n=n, p=p) + expected_means = stats.binom.mean(n, p) + self.assertEqual((3,), binom.mean().get_shape()) + self.assertAllClose(expected_means, binom.mean().eval()) + + def testBinomialVariance(self): + with self.test_session(): + n = 5. + p = [0.1, 0.2, 0.7] + binom = tf.contrib.distributions.Binomial(n=n, p=p) + expected_variances = stats.binom.var(n, p) + self.assertEqual((3,), binom.variance().get_shape()) + self.assertAllClose(expected_variances, binom.variance().eval()) + + def testBinomialMode(self): + with self.test_session(): + n = 5. + p = [0.1, 0.2, 0.7] + binom = tf.contrib.distributions.Binomial(n=n, p=p) + expected_modes = [0., 1, 4] + self.assertEqual((3,), binom.mode().get_shape()) + self.assertAllClose(expected_modes, binom.mode().eval()) + + def testBinomialMultipleMode(self): + with self.test_session(): + n = 9. + p = [0.1, 0.2, 0.7] + binom = tf.contrib.distributions.Binomial(n=n, p=p) + # For the case where (n + 1) * p is an integer, the modes are: + # (n + 1) * p and (n + 1) * p - 1. In this case, we get back + # the larger of the two modes. + expected_modes = [1., 2, 7] + self.assertEqual((3,), binom.mode().get_shape()) + self.assertAllClose(expected_modes, binom.mode().eval()) + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py index 1a3f5eaf66c..23833a246b9 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py @@ -65,7 +65,7 @@ class DirichletMultinomialTest(tf.test.TestCase): dist.pmf([3., 0, 2]).eval() with self.assertRaisesOpError('Condition x >= 0.*'): dist.pmf([-1., 4, 2]).eval() - with self.assertRaisesOpError('Condition x == y.*'): + with self.assertRaisesOpError('counts do not sum to n'): dist.pmf([3., 3, 0]).eval() def testPmf_non_integer_counts(self): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/multinomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/multinomial_test.py new file mode 100644 index 00000000000..55c7825bf3e --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/multinomial_test.py @@ -0,0 +1,226 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + + +class MultinomialTest(tf.test.TestCase): + + def testSimpleShapes(self): + with self.test_session(): + p = [.1, .3, .6] + dist = tf.contrib.distributions.Multinomial(n=1., p=p) + self.assertEqual(3, dist.event_shape().eval()) + self.assertAllEqual([], dist.batch_shape().eval()) + self.assertEqual(tf.TensorShape([3]), dist.get_event_shape()) + self.assertEqual(tf.TensorShape([]), dist.get_batch_shape()) + + def testComplexShapes(self): + with self.test_session(): + p = 0.5 * np.ones([3, 2, 2], dtype=np.float32) + n = [[3., 2], [4, 5], [6, 7]] + dist = tf.contrib.distributions.Multinomial(n=n, p=p) + self.assertEqual(2, dist.event_shape().eval()) + self.assertAllEqual([3, 2], dist.batch_shape().eval()) + self.assertEqual(tf.TensorShape([2]), dist.get_event_shape()) + self.assertEqual(tf.TensorShape([3, 2]), dist.get_batch_shape()) + + def testNProperty(self): + p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]] + n = [[3.], [4]] + with self.test_session(): + dist = tf.contrib.distributions.Multinomial(n=n, p=p) + self.assertEqual((2, 1), dist.n.get_shape()) + self.assertAllClose(n, dist.n.eval()) + + def testPProperty(self): + p = [[0.1, 0.2, 0.7]] + with self.test_session(): + dist = tf.contrib.distributions.Multinomial(n=3., p=p) + self.assertEqual((1, 3), dist.p.get_shape()) + self.assertEqual((1, 3), dist.logits.get_shape()) + self.assertAllClose(p, dist.p.eval()) + + def testLogitsProperty(self): + logits = [[0., 9., -0.5]] + with self.test_session(): + multinom = tf.contrib.distributions.Multinomial(n=3., logits=logits) + self.assertEqual((1, 3), multinom.p.get_shape()) + self.assertEqual((1, 3), multinom.logits.get_shape()) + self.assertAllClose(logits, multinom.logits.eval()) + + def testPmfNandCountsAgree(self): + p = [[0.1, 0.2, 0.7]] + n = [[5.]] + with self.test_session(): + dist = tf.contrib.distributions.Multinomial(n=n, p=p) + dist.pmf([2., 3, 0]).eval() + dist.pmf([3., 0, 2]).eval() + with self.assertRaisesOpError('Condition x >= 0.*'): + dist.pmf([-1., 4, 2]).eval() + with self.assertRaisesOpError('counts do not sum to n'): + dist.pmf([3., 3, 0]).eval() + + def testPmf_non_integer_counts(self): + p = [[0.1, 0.2, 0.7]] + n = [[5.]] + with self.test_session(): + # No errors with integer n. + multinom = tf.contrib.distributions.Multinomial(n=n, p=p) + multinom.pmf([2., 1, 2]).eval() + multinom.pmf([3., 0, 2]).eval() + # Counts don't sum to n. + with self.assertRaisesOpError('counts do not sum to n'): + multinom.pmf([2., 3, 2]).eval() + # Counts are non-integers. + with self.assertRaisesOpError('Condition x == y.*'): + multinom.pmf([1.0, 2.5, 1.5]).eval() + + multinom = tf.contrib.distributions.Multinomial( + n=n, p=p, validate_args=False) + multinom.pmf([1., 2., 2.]).eval() + # Non-integer arguments work. + multinom.pmf([1.0, 2.5, 1.5]).eval() + + def testPmfBothZeroBatches(self): + with self.test_session(): + # Both zero-batches. 
No broadcast + p = [0.5, 0.5] + counts = [1., 0] + pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts) + self.assertAllClose(0.5, pmf.eval()) + self.assertEqual((), pmf.get_shape()) + + def testPmfBothZeroBatchesNontrivialN(self): + with self.test_session(): + # Both zero-batches. No broadcast + p = [0.1, 0.9] + counts = [3., 2] + dist = tf.contrib.distributions.Multinomial(n=5., p=p) + pmf = dist.pmf(counts) + # 5 choose 3 = 5 choose 2 = 10. 10 * (.9)^2 * (.1)^3 = 81/10000. + self.assertAllClose(81./10000, pmf.eval()) + self.assertEqual((), pmf.get_shape()) + + def testPmfPStretchedInBroadcastWhenSameRank(self): + with self.test_session(): + p = [[0.1, 0.9]] + counts = [[1., 0], [0, 1]] + pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts) + self.assertAllClose([0.1, 0.9], pmf.eval()) + self.assertEqual((2), pmf.get_shape()) + + def testPmfPStretchedInBroadcastWhenLowerRank(self): + with self.test_session(): + p = [0.1, 0.9] + counts = [[1., 0], [0, 1]] + pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts) + self.assertAllClose([0.1, 0.9], pmf.eval()) + self.assertEqual((2), pmf.get_shape()) + + def testPmfCountsStretchedInBroadcastWhenSameRank(self): + with self.test_session(): + p = [[0.1, 0.9], [0.7, 0.3]] + counts = [[1., 0]] + pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts) + self.assertAllClose(pmf.eval(), [0.1, 0.7]) + self.assertEqual((2), pmf.get_shape()) + + def testPmfCountsStretchedInBroadcastWhenLowerRank(self): + with self.test_session(): + p = [[0.1, 0.9], [0.7, 0.3]] + counts = [1., 0] + pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts) + self.assertAllClose(pmf.eval(), [0.1, 0.7]) + self.assertEqual(pmf.get_shape(), (2)) + + def testPmfShapeCountsStretched_N(self): + with self.test_session(): + # [2, 2, 2] + p = [[[0.1, 0.9], [0.1, 0.9]], [[0.7, 0.3], [0.7, 0.3]]] + # [2, 2] + n = [[3., 3], [3, 3]] + # [2] + counts = [2., 1] + pmf = tf.contrib.distributions.Multinomial(n=n, p=p).pmf(counts) + pmf.eval() + self.assertEqual(pmf.get_shape(), (2, 2)) + + def testPmfShapeCountsPStretched_N(self): + with self.test_session(): + p = [0.1, 0.9] + counts = [3., 2] + n = np.full([4, 3], 5., dtype=np.float32) + pmf = tf.contrib.distributions.Multinomial(n=n, p=p).pmf(counts) + pmf.eval() + self.assertEqual((4, 3), pmf.get_shape()) + + def testMultinomialMean(self): + with self.test_session(): + n = 5. + p = [0.1, 0.2, 0.7] + dist = tf.contrib.distributions.Multinomial(n=n, p=p) + expected_means = 5 * np.array(p, dtype=np.float32) + self.assertEqual((3,), dist.mean().get_shape()) + self.assertAllClose(expected_means, dist.mean().eval()) + + def testMultinomialVariance(self): + with self.test_session(): + n = 5. + p = [0.1, 0.2, 0.7] + dist = tf.contrib.distributions.Multinomial(n=n, p=p) + expected_variances = [ + [9./20, -1/10, -7/20], [-1/10, 4/5, -7/10], [-7/20, -7/10, 21/20]] + self.assertEqual((3, 3), dist.variance().get_shape()) + self.assertAllClose(expected_variances, dist.variance().eval()) + + def testMultinomialVariance_batch(self): + with self.test_session(): + # Shape [2] + n = [5.] 
* 2 + # Shape [4, 1, 2] + p = [[[0.1, 0.9]], [[0.1, 0.9]]] * 2 + dist = tf.contrib.distributions.Multinomial(n=n, p=p) + # Shape [2, 2] + inner_var = [[9./20, -9/20], [-9/20, 9/20]] + # Shape [4, 2, 2, 2] + expected_variances = [[inner_var, inner_var]] * 4 + self.assertEqual((4, 2, 2, 2), dist.variance().get_shape()) + self.assertAllClose(expected_variances, dist.variance().eval()) + + def testVariance_multidimensional(self): + # Shape [3, 5, 4] + p = np.random.dirichlet([.25, .25, .25, .25], [3, 5]).astype(np.float32) + # Shape [6, 3, 3] + p2 = np.random.dirichlet([.3, .3, .4], [6, 3]).astype(np.float32) + + ns = np.random.randint(low=1, high=11, size=[3, 5]).astype(np.float32) + ns2 = np.random.randint(low=1, high=11, size=[6, 1]).astype(np.float32) + + with self.test_session(): + dist = tf.contrib.distributions.Multinomial(ns, p) + dist2 = tf.contrib.distributions.Multinomial(ns2, p2) + + variance = dist.variance() + variance2 = dist2.variance() + self.assertEqual((3, 5, 4, 4), variance.get_shape()) + self.assertEqual((6, 3, 3, 3), variance2.get_shape()) + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/distributions/python/ops/bernoulli.py b/tensorflow/contrib/distributions/python/ops/bernoulli.py index fe5826e491f..1db599b3fea 100644 --- a/tensorflow/contrib/distributions/python/ops/bernoulli.py +++ b/tensorflow/contrib/distributions/python/ops/bernoulli.py @@ -19,15 +19,13 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops import distribution +from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import kullback_leibler # pylint: disable=line-too-long -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops @@ -38,10 +36,6 @@ class Bernoulli(distribution.Distribution): The Bernoulli distribution is parameterized by p, the probability of a positive event. - - Note, the following methods of the base class aren't implemented: - * cdf - * log_cdf """ def __init__(self, @@ -64,10 +58,10 @@ class Bernoulli(distribution.Distribution): dtype: dtype for samples. validate_args: Whether to assert that `0 <= p <= 1`. If not validate_args, `log_pmf` may return nans. - allow_nan_stats: Boolean, default False. If False, raise an exception if - a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If True, batch members with valid parameters leading to undefined - statistics will return NaN for this statistic. + allow_nan_stats: Boolean, default `False`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member. If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. name: A name for this distribution. 
Raises: @@ -77,27 +71,8 @@ class Bernoulli(distribution.Distribution): self._name = name self._dtype = dtype self._validate_args = validate_args - check_op = check_ops.assert_less_equal - if p is None and logits is None: - raise ValueError("Must pass p or logits.") - elif p is not None and logits is not None: - raise ValueError("Must pass either p or logits, not both.") - elif p is None: - with ops.op_scope([logits], name): - self._logits = array_ops.identity(logits, name="logits") - with ops.name_scope(name): - with ops.name_scope("p"): - self._p = math_ops.sigmoid(self._logits) - elif logits is None: - with ops.name_scope(name): - with ops.name_scope("p"): - p = array_ops.identity(p) - one = constant_op.constant(1., p.dtype) - zero = constant_op.constant(0., p.dtype) - self._p = control_flow_ops.with_dependencies( - [check_op(p, one), check_op(zero, p)] if validate_args else [], p) - with ops.name_scope("logits"): - self._logits = math_ops.log(self._p) - math_ops.log(1. - self._p) + self._logits, self._p = distribution_util.get_logits_and_prob( + name=name, logits=logits, p=p, validate_args=validate_args) with ops.name_scope(name): with ops.name_scope("q"): self._q = 1. - self._p @@ -184,8 +159,12 @@ class Bernoulli(distribution.Distribution): event = ops.convert_to_tensor(event, name="event") event = math_ops.cast(event, self.logits.dtype) logits = self.logits - if ((event.get_shape().ndims is not None) or - (logits.get_shape().ndims is not None) or + # sigmoid_cross_entropy_with_logits doesn't broadcast shape, + # so we do this here. + # TODO(b/30637701): Check dynamic shape, and don't broadcast if the + # dynamic shapes are the same. + if (not event.get_shape().is_fully_defined() or + not logits.get_shape().is_fully_defined() or event.get_shape() != logits.get_shape()): logits = array_ops.ones_like(event) * logits event = array_ops.ones_like(logits) * event @@ -206,8 +185,7 @@ class Bernoulli(distribution.Distribution): with ops.name_scope(self.name): with ops.op_scope([self.p, n], name): n = ops.convert_to_tensor(n, name="n") - new_shape = array_ops.concat( - 0, [array_ops.expand_dims(n, 0), self.batch_shape()]) + new_shape = array_ops.concat(0, ([n], self.batch_shape())) uniform = random_ops.random_uniform( new_shape, seed=seed, dtype=dtypes.float32) sample = math_ops.less(uniform, self.p) diff --git a/tensorflow/contrib/distributions/python/ops/beta.py b/tensorflow/contrib/distributions/python/ops/beta.py index 2bd64180682..fcf4a9056c3 100644 --- a/tensorflow/contrib/distributions/python/ops/beta.py +++ b/tensorflow/contrib/distributions/python/ops/beta.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== """The Beta distribution class.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -95,6 +96,7 @@ class Beta(distribution.Distribution): x = [.2, .3, .9] dist.pdf(x) # Shape [2] ``` + """ def __init__(self, a, b, validate_args=True, allow_nan_stats=False, @@ -102,20 +104,20 @@ class Beta(distribution.Distribution): """Initialize a batch of Beta distributions. Args: - a: Positive `float` or `double` tensor with shape broadcastable to + a: Positive floating point tensor with shape broadcastable to `[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different Beta distributions. This also defines the dtype of the distribution. 
- b: Positive `float` or `double` tensor with shape broadcastable to + b: Positive floating point tensor with shape broadcastable to `[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different Beta distributions. validate_args: Whether to assert valid values for parameters `a` and `b`, - and `x` in `prob` and `log_prob`. If False, correct behavior is not + and `x` in `prob` and `log_prob`. If `False`, correct behavior is not guaranteed. - allow_nan_stats: Boolean, default False. If False, raise an exception if - a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If True, batch members with valid parameters leading to undefined - statistics will return NaN for this statistic. + allow_nan_stats: Boolean, default `False`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member. If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. name: The name to prefix Ops created by this distribution class. Examples: @@ -127,6 +129,7 @@ class Beta(distribution.Distribution): # Define a 2-batch. dist = Beta([1.0, 2.0], [4.0, 5.0]) ``` + """ with ops.op_scope([a, b], name): with ops.control_dependencies([ @@ -276,8 +279,14 @@ class Beta(distribution.Distribution): array_ops.ones_like(a_b_sum, dtype=self.dtype))) else: return control_flow_ops.with_dependencies([ - check_ops.assert_less(one, a), - check_ops.assert_less(one, b)], mode) + check_ops.assert_less( + one, a, + message="mode not defined for components of a <= 1" + ), + check_ops.assert_less( + one, b, + message="mode not defined for components of b <= 1" + )], mode) def entropy(self, name="entropy"): """Entropy of the distribution in nats.""" @@ -306,7 +315,7 @@ class Beta(distribution.Distribution): """`Log(P[counts])`, computed for every batch member. Args: - x: Non-negative `float` or `double`, tensor whose shape can + x: Non-negative floating point tensor whose shape can be broadcast with `self.a` and `self.b`. For fixed leading dimensions, the last dimension represents counts for the corresponding Beta distribution in `self.a` and `self.b`. `x` is only legal if @@ -334,7 +343,7 @@ class Beta(distribution.Distribution): """`P[x]`, computed for every batch member. Args: - x: Non-negative `float`, `double` tensor whose shape can + x: Non-negative floating point tensor whose shape can be broadcast with `self.a` and `self.b`. For fixed leading dimensions, the last dimension represents x for the corresponding Beta distribution in `self.a` and `self.b`. `x` is only legal if is diff --git a/tensorflow/contrib/distributions/python/ops/binomial.py b/tensorflow/contrib/distributions/python/ops/binomial.py new file mode 100644 index 00000000000..9978d0ad613 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/binomial.py @@ -0,0 +1,340 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""The Binomial distribution class."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=line-too-long
+
+from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.contrib.distributions.python.ops import distribution_util
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+
+# pylint: enable=line-too-long
+
+
+class Binomial(distribution.Distribution):
+  """Binomial distribution.
+
+  This distribution is parameterized by `p`, the probability of success in
+  each trial, and `n`, the total number of trials.
+
+  #### Mathematical details
+
+  The Binomial is a distribution over the number of successes in `n` independent
+  trials, with each trial having the same probability of success `p`.
+  The probability mass function (pmf) is:
+
+  ```pmf(k) = n! / (k! * (n - k)!) * (p)^k * (1 - p)^(n - k)```
+
+  #### Examples
+
+  Create a single distribution, corresponding to 5 coin flips.
+
+  ```python
+  dist = Binomial(n=5., p=.5)
+  ```
+
+  Create a single distribution (using logits), corresponding to 5 coin flips.
+
+  ```python
+  dist = Binomial(n=5., logits=0.)
+  ```
+
+  Create 3 distributions, with the third the most likely to have successes.
+
+  ```python
+  p = [.2, .3, .8]
+  # n will be broadcast to [4., 4., 4.], to match p.
+  dist = Binomial(n=4., p=p)
+  ```
+
+  The distribution functions can be evaluated on counts.
+
+  ```python
+  # counts same shape as p.
+  counts = [1., 2, 3]
+  dist.prob(counts)  # Shape [3]
+
+  # p will be broadcast to [[.2, .3, .8], [.2, .3, .8]] to match counts.
+  counts = [[1., 2, 1], [2, 2, 4]]
+  dist.prob(counts)  # Shape [2, 3]
+
+  # p will be broadcast to shape [5, 7, 3] to match counts.
+  counts = [[...]]  # Shape [5, 7, 3]
+  dist.prob(counts)  # Shape [5, 7, 3]
+  ```
+  """
+
+  def __init__(self,
+               n,
+               logits=None,
+               p=None,
+               validate_args=True,
+               allow_nan_stats=False,
+               name="Binomial"):
+    """Initialize a batch of Binomial distributions.
+
+    Args:
+      n: Non-negative floating point tensor with shape broadcastable to
+        `[N1,..., Nm]` with `m >= 0` and the same dtype as `p` or `logits`.
+        Defines this as a batch of `N1 x ... x Nm` different Binomial
+        distributions. Its components should be equal to integer values.
+      logits: Floating point tensor representing the log-odds of a
+        positive event with shape broadcastable to `[N1,..., Nm]` `m >= 0`, and
+        the same dtype as `n`. Each entry represents logits for the probability
+        of success for independent Binomial distributions.
+      p: Non-negative floating point tensor with shape broadcastable to
+        `[N1,..., Nm]` `m >= 0`, `p in [0, 1]`. Each entry represents the
+        probability of success for independent Binomial distributions.
+      validate_args: Whether to assert valid values for parameters `n` and `p`,
+        and `x` in `prob` and `log_prob`. If `False`, correct behavior is not
+        guaranteed.
+      allow_nan_stats: Boolean, default `False`. If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member.
If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. + name: The name to prefix Ops created by this distribution class. + + Examples: + + ```python + # Define 1-batch of a binomial distribution. + dist = Binomial(n=2., p=.9) + + # Define a 2-batch. + dist = Binomial(n=[4., 5], p=[.1, .3]) + ``` + + """ + + self._logits, self._p = distribution_util.get_logits_and_prob( + name=name, logits=logits, p=p, validate_args=validate_args) + + with ops.op_scope([n], name): + with ops.control_dependencies([ + check_ops.assert_non_negative( + n, message="n has negative components."), + distribution_util.assert_integer_form( + n, message="n has non-integer components." + )] if validate_args else []): + self._n = array_ops.identity(n, name="convert_n") + + self._name = name + self._validate_args = validate_args + self._allow_nan_stats = allow_nan_stats + + self._mean = self._n * self._p + self._get_batch_shape = self._mean.get_shape() + self._get_event_shape = tensor_shape.TensorShape([]) + + @property + def name(self): + """Name to prepend to all ops.""" + return self._name + + @property + def dtype(self): + """dtype of samples from this distribution.""" + return self._p.dtype + + @property + def validate_args(self): + """Boolean describing behavior on invalid input.""" + return self._validate_args + + @property + def allow_nan_stats(self): + """Boolean describing behavior when a stat is undefined for batch member.""" + return self._allow_nan_stats + + def batch_shape(self, name="batch_shape"): + """Batch dimensions of this instance as a 1-D int32 `Tensor`. + + The product of the dimensions of the `batch_shape` is the number of + independent distributions of this kind the instance represents. + + Args: + name: name to give to the op + + Returns: + `Tensor` `batch_shape` + """ + return array_ops.shape(self._mean) + + def get_batch_shape(self): + """`TensorShape` available at graph construction time. + + Same meaning as `batch_shape`. May be only partially defined. + + Returns: + batch shape + """ + return self._get_batch_shape + + def event_shape(self, name="event_shape"): + """Shape of a sample from a single distribution as a 1-D int32 `Tensor`. + + Args: + name: name to give to the op + + Returns: + `Tensor` `event_shape` + """ + with ops.name_scope(self.name): + with ops.op_scope([], name): + return constant_op.constant([], name=name, dtype=dtypes.int32) + + def get_event_shape(self): + """`TensorShape` available at graph construction time. + + Same meaning as `event_shape`. May be only partially defined. + + Returns: + event shape + """ + return self._get_event_shape + + @property + def n(self): + """Number of trials.""" + return self._n + + @property + def logits(self): + """Log-odds.""" + return self._logits + + @property + def p(self): + """Probability of success.""" + return self._p + + def mean(self, name="mean"): + """Mean of the distribution.""" + with ops.name_scope(self.name): + return array_ops.identity(self._mean, name=name) + + def variance(self, name="variance"): + """Variance of the distribution.""" + with ops.name_scope(self.name): + with ops.op_scope([self._n, self._p], name): + return self._n * self._p * (1 - self._p) + + def std(self, name="std"): + """Standard deviation of the distribution.""" + with ops.name_scope(self.name): + with ops.op_scope([self._n, self._p], name): + return math_ops.sqrt(self.variance()) + + def mode(self, name="mode"): + """Mode of the distribution. 
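+
+    Concretely, this evaluates to `floor((n + 1) * p)`.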
+
+    Note that when `(n + 1) * p` is an integer, there are actually two modes.
+    Namely, `(n + 1) * p` and `(n + 1) * p - 1` are both modes. Here we return
+    only the larger of the two modes.
+
+    Args:
+      name: The name for this op.
+
+    Returns:
+      The mode of the Binomial distribution.
+    """
+    with ops.name_scope(self.name):
+      with ops.op_scope([self._n, self._p], name):
+        return math_ops.floor((self._n + 1) * self._p)
+
+  def log_prob(self, counts, name="log_prob"):
+    """`Log(P[counts])`, computed for every batch member.
+
+    For each batch member of counts `k`, `P[counts]` is the probability that
+    after sampling `n` draws from this Binomial distribution, the number of
+    successes is `k`. Note that different sequences of draws can result in the
+    same counts, thus the probability includes a combinatorial coefficient.
+
+    Args:
+      counts: Non-negative tensor with dtype `dtype` and whose shape can be
+        broadcast with `self.p` and `self.n`. `counts` is only legal if it is
+        less than or equal to `n` and its components are equal to integer
+        values.
+      name: Name to give this Op, defaults to "log_prob".
+
+    Returns:
+      Log probabilities for each record, shape `[N1,...,Nm]`.
+    """
+    n = self._n
+    p = self._p
+    with ops.name_scope(self.name):
+      with ops.op_scope([self._n, self._p, counts], name):
+        counts = self._check_counts(counts)
+
+        log_unnormalized_prob = counts * math_ops.log(p) + (
+            n - counts) * math_ops.log(1 - p)
+
+        combinations = math_ops.lgamma(n + 1) - math_ops.lgamma(
+            counts + 1) - math_ops.lgamma(n - counts + 1)
+        log_prob = log_unnormalized_prob + combinations
+        return log_prob
+
+  def prob(self, counts, name="prob"):
+    """`P[counts]`, computed for every batch member.
+
+    For each batch member of counts `k`, `P[counts]` is the probability that
+    after sampling `n` draws from this Binomial distribution, the number of
+    successes is `k`. Note that different sequences of draws can result in the
+    same counts, thus the probability includes a combinatorial coefficient.
+
+    Args:
+      counts: Non-negative tensor with dtype `dtype` and whose shape can be
+        broadcast with `self.p` and `self.n`. `counts` is only legal if it is
+        less than or equal to `n` and its components are equal to integer
+        values.
+      name: Name to give this Op, defaults to "prob".
+
+    Returns:
+      Probabilities for each record, shape `[N1,...,Nm]`.
+    """
+    return super(Binomial, self).prob(counts, name=name)
+
+  @property
+  def is_continuous(self):
+    return False
+
+  @property
+  def is_reparameterized(self):
+    return False
+
+  def _check_counts(self, counts):
+    """Check counts for proper shape, values, then return tensor version."""
+    counts = ops.convert_to_tensor(counts, name="counts_before_deps")
+    if not self.validate_args:
+      return counts
+    return control_flow_ops.with_dependencies([
+        check_ops.assert_non_negative(
+            counts, message="counts has negative components."),
+        check_ops.assert_less_equal(
+            counts, self._n, message="counts are not less than or equal to n."),
+        distribution_util.assert_integer_form(
+            counts, message="counts have non-integer components.")], counts)
diff --git a/tensorflow/contrib/distributions/python/ops/categorical.py b/tensorflow/contrib/distributions/python/ops/categorical.py
index 64572ed7885..e79a732a0c9 100644
--- a/tensorflow/contrib/distributions/python/ops/categorical.py
+++ b/tensorflow/contrib/distributions/python/ops/categorical.py
@@ -34,11 +34,6 @@ class Categorical(distribution.Distribution):
 
   The categorical distribution is parameterized by the log-probabilities
   of a set of classes.
- - Note, the following methods of the base class aren't implemented: - * mean - * cdf - * log_cdf """ def __init__( @@ -57,10 +52,10 @@ class Categorical(distribution.Distribution): indexes into the classes. dtype: The type of the event samples (default: int32). validate_args: Unused in this distribution. - allow_nan_stats: Boolean, default False. If False, raise an exception if - a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If True, batch members with valid parameters leading to undefined - statistics will return NaN for this statistic. + allow_nan_stats: Boolean, default `False`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member. If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. name: A name for this distribution (optional). """ self._allow_nan_stats = allow_nan_stats @@ -177,8 +172,7 @@ class Categorical(distribution.Distribution): samples = math_ops.cast(samples, self._dtype) ret = array_ops.reshape( array_ops.transpose(samples), - array_ops.concat( - 0, [array_ops.expand_dims(n, 0), self.batch_shape()])) + array_ops.concat(0, ([n], self.batch_shape()))) ret.set_shape(tensor_shape.vector(tensor_util.constant_value(n)) .concatenate(self.get_batch_shape())) return ret diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py index 65840373f12..e09ef6324b8 100644 --- a/tensorflow/contrib/distributions/python/ops/chi2.py +++ b/tensorflow/contrib/distributions/python/ops/chi2.py @@ -42,15 +42,15 @@ class Chi2(gamma.Gamma): """Construct Chi2 distributions with parameter `df`. Args: - df: `float` or `double` tensor, the degrees of freedom of the + df: Floating point tensor, the degrees of freedom of the distribution(s). `df` must contain only positive values. validate_args: Whether to assert that `df > 0`, and that `x > 0` in the - methods `prob(x)` and `log_prob(x)`. If `validate_args` is False + methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False` and the inputs are invalid, correct behavior is not guaranteed. - allow_nan_stats: Boolean, default False. If False, raise an exception if - a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If True, batch members with valid parameters leading to undefined - statistics will return NaN for this statistic. + allow_nan_stats: Boolean, default `False`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member. If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. name: The name to prepend to all ops created by this distribution. 
""" # Even though all stats of chi2 are defined for valid parameters, this is diff --git a/tensorflow/contrib/distributions/python/ops/dirichlet.py b/tensorflow/contrib/distributions/python/ops/dirichlet.py index b4f59d5bd8c..25aee5cf03e 100644 --- a/tensorflow/contrib/distributions/python/ops/dirichlet.py +++ b/tensorflow/contrib/distributions/python/ops/dirichlet.py @@ -19,9 +19,8 @@ from __future__ import print_function # pylint: disable=line-too-long -import numpy as np - from tensorflow.contrib.distributions.python.ops import distribution +from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -29,7 +28,6 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import special_math_ops @@ -37,24 +35,6 @@ from tensorflow.python.ops import special_math_ops # pylint: enable=line-too-long -def _assert_close(x, y, data=None, summarize=None, name=None): - if x.dtype.is_integer: - return check_ops.assert_equal( - x, y, data=data, summarize=summarize, name=name) - - with ops.op_scope([x, y, data], name, "assert_close"): - x = ops.convert_to_tensor(x, name="x") - y = ops.convert_to_tensor(y, name="y") - tol = np.finfo(x.dtype.as_numpy_dtype).resolution - if data is None: - data = [ - "Condition x ~= y did not hold element-wise: x = ", x.name, x, "y = ", - y.name, y - ] - condition = math_ops.reduce_all(math_ops.less_equal(math_ops.abs(x-y), tol)) - return logging_ops.Assert(condition, data, summarize=summarize) - - class Dirichlet(distribution.Distribution): """Dirichlet distribution. @@ -117,6 +97,7 @@ class Dirichlet(distribution.Distribution): x = [.2, .3, .5] dist.prob(x) # Shape [2] ``` + """ def __init__(self, @@ -127,16 +108,16 @@ class Dirichlet(distribution.Distribution): """Initialize a batch of Dirichlet distributions. Args: - alpha: Positive `float` or `double` tensor with shape broadcastable to + alpha: Positive floating point tensor with shape broadcastable to `[N1,..., Nm, k]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different `k` class Dirichlet distributions. validate_args: Whether to assert valid values for parameters `alpha` and - `x` in `prob` and `log_prob`. If False, correct behavior is not + `x` in `prob` and `log_prob`. If `False`, correct behavior is not guaranteed. - allow_nan_stats: Boolean, default False. If False, raise an exception if - a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If True, batch members with valid parameters leading to undefined - statistics will return NaN for this statistic. + allow_nan_stats: Boolean, default `False`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member. If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. name: The name to prefix Ops created by this distribution class. Examples: @@ -149,6 +130,7 @@ class Dirichlet(distribution.Distribution): # Define a 2-batch of 3-class distributions. 
     dist = Dirichlet([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
     ```
+
     """
     with ops.op_scope([alpha], name):
       alpha = ops.convert_to_tensor(alpha, name="alpha_before_deps")
@@ -302,7 +284,9 @@ class Dirichlet(distribution.Distribution):
               array_ops.ones_like(self._alpha, dtype=self.dtype)))
       else:
         return control_flow_ops.with_dependencies([
-            check_ops.assert_less(one, self._alpha)
+            check_ops.assert_less(
+                one, self._alpha,
+                message="mode not defined for components of alpha <= 1")
         ], mode)
 
   def entropy(self, name="entropy"):
@@ -334,7 +318,7 @@ class Dirichlet(distribution.Distribution):
     """`Log(P[counts])`, computed for every batch member.
 
     Args:
-      x: Non-negative `float` or `double`, tensor whose shape can
+      x: Non-negative tensor with dtype `dtype` and whose shape can
         be broadcast with `self.alpha`.  For fixed leading dimensions, the last
         dimension represents counts for the corresponding Dirichlet distribution
         in `self.alpha`. `x` is only legal if it sums up to one.
@@ -359,7 +343,7 @@ class Dirichlet(distribution.Distribution):
     """`P[x]`, computed for every batch member.
 
     Args:
-      x: Non-negative `float`, `double` tensor whose shape can
+      x: Non-negative tensor with dtype `dtype` and whose shape can
        be broadcast with `self.alpha`.  For fixed leading dimensions, the last
        dimension represents x for the corresponding Dirichlet distribution in
        `self.alpha`. `x` is only legal if it sums up to one.
@@ -407,7 +391,8 @@ class Dirichlet(distribution.Distribution):
     x = ops.convert_to_tensor(x, name="x_before_deps")
     candidate_one = math_ops.reduce_sum(x, reduction_indices=[-1])
     one = constant_op.constant(1., self.dtype)
-    dependencies = [check_ops.assert_positive(x), check_ops.assert_less(x, one),
-                    _assert_close(one, candidate_one)
+    dependencies = [check_ops.assert_positive(x), check_ops.assert_less(
+        x, one, message="x has components greater than or equal to 1"),
+                    distribution_util.assert_close(one, candidate_one)
                    ] if self.validate_args else []
     return control_flow_ops.with_dependencies(dependencies, x)
diff --git a/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py b/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py
index 7c779fff065..67cdd566c67 100644
--- a/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py
+++ b/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py
@@ -13,13 +13,15 @@
 # limitations under the License.
# ============================================================================== """The Dirichlet Multinomial distribution class.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function # pylint: disable=line-too-long -from tensorflow.contrib.distributions.python.ops import distribution # pylint: disable=line-too-long +from tensorflow.contrib.distributions.python.ops import distribution +from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -30,34 +32,6 @@ from tensorflow.python.ops import special_math_ops # pylint: enable=line-too-long -def _assert_integer_form(x): - """Check x for integer components (or floats that are equal to integers).""" - x = ops.convert_to_tensor(x, name='x') - casted_x = math_ops.to_int64(x) - return check_ops.assert_equal(x, math_ops.cast( - math_ops.round(casted_x), x.dtype)) - - -def _log_combinations(n, counts, name='log_combinations'): - """Log number of ways counts could have come in.""" - # First a bit about the number of ways counts could have come in: - # E.g. if counts = [1, 2], then this is 3 choose 2. - # In general, this is (sum counts)! / sum(counts!) - # The sum should be along the last dimension of counts. This is the - # "distribution" dimension. Here n a priori represents the sum of counts. - with ops.op_scope([counts], name): - # To compute factorials, use the fact that Gamma(n + 1) = n! - # Compute two terms, each a sum over counts. Compute each for each - # batch member. - # Log Gamma((sum counts) + 1) = Log((sum counts)!) - total_permutations = math_ops.lgamma(n + 1) - # sum(Log Gamma(counts + 1)) = Log sum(counts!) - counts_factorial = math_ops.lgamma(counts + 1) - redundant_permutations = math_ops.reduce_sum(counts_factorial, - reduction_indices=[-1]) - return total_permutations - redundant_permutations - - class DirichletMultinomial(distribution.Distribution): """DirichletMultinomial mixture distribution. @@ -126,6 +100,7 @@ class DirichletMultinomial(distribution.Distribution): counts = [2, 1, 0] dist.pmf(counts) # Shape [2] ``` + """ # TODO(b/27419586) Change docstring for dtype of alpha once int allowed. @@ -134,26 +109,26 @@ class DirichletMultinomial(distribution.Distribution): alpha, validate_args=True, allow_nan_stats=False, - name='DirichletMultinomial'): + name="DirichletMultinomial"): """Initialize a batch of DirichletMultinomial distributions. Args: - n: Non-negative `float` or `double` tensor, whose dtype is the same as + n: Non-negative floating point tensor, whose dtype is the same as `alpha`. The shape is broadcastable to `[N1,..., Nm]` with `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different Dirichlet - multinomial distributions. Its components should be equal to integral + multinomial distributions. Its components should be equal to integer values. - alpha: Positive `float` or `double` tensor, whose dtype is the same as + alpha: Positive floating point tensor, whose dtype is the same as `n` with shape broadcastable to `[N1,..., Nm, k]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different `k` class Dirichlet multinomial distributions. validate_args: Whether to assert valid values for parameters `alpha` and - `n`, and `x` in `prob` and `log_prob`. If False, correct behavior is + `n`, and `x` in `prob` and `log_prob`. If `False`, correct behavior is not guaranteed. 
-      allow_nan_stats: Boolean, default False.  If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats: Boolean, default `False`.  If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member.  If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to prefix Ops created by this distribution class.
 
     Examples:
@@ -166,6 +141,7 @@ class DirichletMultinomial(distribution.Distribution):
       # Define a 2-batch of 3-class distributions.
       dist = DirichletMultinomial([3., 4], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
       ```
+
     """
     self._allow_nan_stats = allow_nan_stats
     self._validate_args = validate_args
@@ -221,7 +197,7 @@ class DirichletMultinomial(distribution.Distribution):
     """dtype of samples from this distribution."""
     return self._alpha.dtype
 
-  def mean(self, name='mean'):
+  def mean(self, name="mean"):
     """Class means for every batch member."""
     alpha = self._alpha
     alpha_sum = self._alpha_sum
@@ -231,7 +207,7 @@ class DirichletMultinomial(distribution.Distribution):
     mean_no_n = alpha / array_ops.expand_dims(alpha_sum, -1)
     return array_ops.expand_dims(n, -1) * mean_no_n
 
-  def variance(self, name='mean'):
+  def variance(self, name="variance"):
     """Class variances for every batch member.
 
     The variance for each batch member is defined as the following:
@@ -273,7 +249,7 @@ class DirichletMultinomial(distribution.Distribution):
     variance *= array_ops.expand_dims(shared_factor, -1)
     return variance
 
-  def batch_shape(self, name='batch_shape'):
+  def batch_shape(self, name="batch_shape"):
     """Batch dimensions of this instance as a 1-D int32 `Tensor`.
 
     The product of the dimensions of the `batch_shape` is the number of
@@ -299,7 +275,7 @@ class DirichletMultinomial(distribution.Distribution):
     """
     return self._get_batch_shape
 
-  def event_shape(self, name='event_shape'):
+  def event_shape(self, name="event_shape"):
     """Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
 
     Args:
@@ -322,15 +298,15 @@ class DirichletMultinomial(distribution.Distribution):
     """
     return self._get_event_shape
 
-  def cdf(self, x, name='cdf'):
+  def cdf(self, x, name="cdf"):
     raise NotImplementedError(
-        'DirichletMultinomial does not have a well-defined cdf.')
+        "DirichletMultinomial does not have a well-defined cdf.")
 
-  def log_cdf(self, x, name='log_cdf'):
+  def log_cdf(self, x, name="log_cdf"):
     raise NotImplementedError(
-        'DirichletMultinomial does not have a well-defined cdf.')
+        "DirichletMultinomial does not have a well-defined cdf.")
 
-  def log_prob(self, counts, name='log_prob'):
+  def log_prob(self, counts, name="log_prob"):
     """`Log(P[counts])`, computed for every batch member.
 
     For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
     that after sampling `n` draws from this Dirichlet Multinomial
     distribution, the number of draws falling in class `j` is `n_j`.  Note
     that different sequences of draws can result in the same counts, thus the
     probability includes a combinatorial coefficient.
 
     Args:
-      counts: Non-negative `float` or `double` tensor whose dtype is the same
-        `self` and whose shape can be broadcast with `self.alpha`.  For fixed
-        leading dimensions, the last dimension represents counts for the
-        corresponding Dirichlet Multinomial distribution in `self.alpha`.
-        `counts` is only legal if it sums up to `n` and its components are
-        equal to integral values.
+ counts: Non-negative tensor with dtype `dtype` and whose shape can be + broadcast with `self.alpha`. For fixed leading dimensions, the last + dimension represents counts for the corresponding Dirichlet Multinomial + distribution in `self.alpha`. `counts` is only legal if it sums up to + `n` and its components are equal to integer values. name: Name to give this Op, defaults to "log_prob". Returns: @@ -359,20 +334,11 @@ class DirichletMultinomial(distribution.Distribution): ordered_prob = (special_math_ops.lbeta(alpha + counts) - special_math_ops.lbeta(alpha)) - log_prob = ordered_prob + _log_combinations(n, counts) - # If alpha = counts = [[]], ordered_prob carries the right shape, which - # is []. However, since reduce_sum([[]]) = [0], log_combinations = [0], - # which is not correct. Luckily, [] + [0] = [], so the sum is fine, but - # shape must be inferred from ordered_prob. We must also make this - # broadcastable with n, so this is multiplied by n to ensure the shape - # is correctly inferred. - # Note also that tf.constant([]).get_shape() = - # TensorShape([Dimension(0)]) - broadcasted_tensor = ordered_prob * n - log_prob.set_shape(broadcasted_tensor.get_shape()) + log_prob = ordered_prob + distribution_util.log_combinations( + n, counts) return log_prob - def prob(self, counts, name='prob'): + def prob(self, counts, name="prob"): """`P[counts]`, computed for every batch member. For each batch of counts `[c_1,...,c_k]`, `P[counts]` is the probability @@ -382,12 +348,11 @@ class DirichletMultinomial(distribution.Distribution): probability includes a combinatorial coefficient. Args: - counts: Non-negative `float` or `double` tensor whose dtype is the same - `self` and whose shape can be broadcast with `self.alpha`. For fixed - leading dimensions, the last dimension represents counts for the - corresponding Dirichlet Multinomial distribution in `self.alpha`. - `counts` is only legal if it sums up to `n` and its components are - equal to integral values. + counts: Non-negative tensor with dtype `dtype` and whose shape can be + broadcast with `self.alpha`. For fixed leading dimensions, the last + dimension represents counts for the corresponding Dirichlet Multinomial + distribution in `self.alpha`. `counts` is only legal if it sums up to + `n` and its components are equal to integer values. name: Name to give this Op, defaults to "prob". 
Returns: @@ -397,18 +362,21 @@ class DirichletMultinomial(distribution.Distribution): def _check_counts(self, counts): """Check counts for proper shape, values, then return tensor version.""" - counts = ops.convert_to_tensor(counts, name='counts') + counts = ops.convert_to_tensor(counts, name="counts") if not self.validate_args: return counts candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1]) return control_flow_ops.with_dependencies([ check_ops.assert_non_negative(counts), - check_ops.assert_equal(self._n, candidate_n), - _assert_integer_form(counts)], counts) + check_ops.assert_equal( + self._n, candidate_n, + message="counts do not sum to n" + ), + distribution_util.assert_integer_form(counts)], counts) def _check_alpha(self, alpha): - alpha = ops.convert_to_tensor(alpha, name='alpha') + alpha = ops.convert_to_tensor(alpha, name="alpha") if not self.validate_args: return alpha return control_flow_ops.with_dependencies( @@ -416,11 +384,12 @@ class DirichletMultinomial(distribution.Distribution): check_ops.assert_positive(alpha)], alpha) def _check_n(self, n): - n = ops.convert_to_tensor(n, name='n') + n = ops.convert_to_tensor(n, name="n") if not self.validate_args: return n return control_flow_ops.with_dependencies( - [check_ops.assert_non_negative(n), _assert_integer_form(n)], n) + [check_ops.assert_non_negative(n), + distribution_util.assert_integer_form(n)], n) @property def is_continuous(self): diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py new file mode 100644 index 00000000000..9c751270032 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -0,0 +1,177 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for probability distributions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import math_ops + + +def assert_close( + x, y, data=None, summarize=None, message=None, name="assert_close"): + """Assert that x and y are within machine epsilon of each other. + + Args: + x: Numeric `Tensor` + y: Numeric `Tensor` + data: The tensors to print out if the condition is `False`. Defaults to + error message and first few entries of `x` and `y`. + summarize: Print this many entries of each tensor. + message: A string to prefix to the default message. + name: A name for this operation (optional). + + Returns: + Op raising `InvalidArgumentError` if |x - y| > machine epsilon.
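+
+  For example, a small usage sketch (illustrative only, not checked-in test
+  code; assumes `tf` is imported and float32 inputs, whose resolution is
+  1e-6):
+
+  ```python
+  x = tf.constant([1., 2.])
+  y = tf.constant([1., 2. + 1e-7])  # differs from x by less than 1e-6
+  # The returned op can gate downstream computation:
+  with tf.control_dependencies([assert_close(x, y)]):
+    x = tf.identity(x)
+  ```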
+ """ + message = message or "" + x = ops.convert_to_tensor(x, name="x") + y = ops.convert_to_tensor(y, name="y") + + if x.dtype.is_integer: + return check_ops.assert_equal( + x, y, data=data, summarize=summarize, message=message, name=name) + + with ops.op_scope([x, y, data], name, "assert_close"): + tol = np.finfo(x.dtype.as_numpy_dtype).resolution + if data is None: + data = [ + message, + "Condition x ~= y did not hold element-wise: x = ", x.name, x, "y = ", + y.name, y + ] + condition = math_ops.reduce_all(math_ops.less_equal(math_ops.abs(x-y), tol)) + return logging_ops.Assert( + condition, data, summarize=summarize) + + +def assert_integer_form( + x, data=None, summarize=None, message=None, name="assert_integer_form"): + """Assert that x has integer components (or floats equal to integers). + + Args: + x: Numeric `Tensor` + data: The tensors to print out if the condition is `False`. Defaults to + error message and first few entries of `x` and `y`. + summarize: Print this many entries of each tensor. + message: A string to prefix to the default message. + name: A name for this operation (optional). + + Returns: + Op raising `InvalidArgumentError` if round(x) != x. + """ + + message = message or "x has non-integer components" + x = ops.convert_to_tensor(x, name="x") + casted_x = math_ops.to_int64(x) + return check_ops.assert_equal( + x, math_ops.cast(math_ops.round(casted_x), x.dtype), + data=data, summarize=summarize, message=message, name=name) + + +def get_logits_and_prob( + logits=None, p=None, multidimensional=False, validate_args=True, name=None): + """Converts logits to probabilities and vice-versa, and returns both. + + Args: + logits: Numeric `Tensor` representing log-odds. + p: Numeric `Tensor` representing probabilities. + multidimensional: Given `p` a [N1, N2, ... k] dimensional tensor, + whether the last dimension represents the probability between k classes. + This will additionally assert that the values in the last dimension + sum to one. If `False`, will instead assert that each value is in + `[0, 1]`. + validate_args: Whether to assert `0 <= p <= 1` if multidimensional is + `False`, otherwise that the last dimension of `p` sums to one. + name: A name for this operation (optional). + + Returns: + Tuple with `logits` and `p`. If `p` has an entry that is `0` or `1`, then + the corresponding entry in the returned logits will be `-Inf` and `Inf` + respectively. + + Raises: + ValueError: if neither `p` nor `logits` were passed in, or both were. + """ + if p is None and logits is None: + raise ValueError("Must pass p or logits.") + elif p is not None and logits is not None: + raise ValueError("Must pass either p or logits, not both.") + elif p is None: + with ops.op_scope([logits], name): + logits = array_ops.identity(logits, name="logits") + with ops.name_scope(name): + with ops.name_scope("p"): + p = math_ops.sigmoid(logits) + elif logits is None: + with ops.name_scope(name): + with ops.name_scope("p"): + p = array_ops.identity(p) + if validate_args: + one = constant_op.constant(1., p.dtype) + dependencies = [check_ops.assert_non_negative(p)] + if multidimensional: + dependencies += [assert_close( + math_ops.reduce_sum(p, reduction_indices=[-1]), + one, message="p does not sum to 1.")] + else: + dependencies += [check_ops.assert_less_equal( + p, one, message="p has components greater than 1.")] + p = control_flow_ops.with_dependencies(dependencies, p) + with ops.name_scope("logits"): + logits = math_ops.log(p) - math_ops.log(1. 
- p) + return (logits, p) + + +def log_combinations(n, counts, name="log_combinations"): + """Multinomial coefficient. + + Given `n` and `counts`, where `counts` has last dimension `k`, we compute + the multinomial coefficient as: + + ```n! / prod_i n_i!``` + + where `i` runs over all `k` classes. + + Args: + n: Numeric `Tensor` broadcastable with `counts`. This represents `n` + outcomes. + counts: Numeric `Tensor` broadcastable with `n`. This represents counts + in `k` classes, where `k` is the last dimension of the tensor. + name: A name for this operation (optional). + + Returns: + `Tensor` representing the log of the multinomial coefficient between `n` + and `counts`. + """ + # First a bit about the number of ways counts could have come in: + # E.g. if counts = [1, 2], then this is 3 choose 2. + # In general, this is (sum counts)! / prod(counts!) + # The sum should be along the last dimension of counts. This is the + # "distribution" dimension. Here n a priori represents the sum of counts. + with ops.op_scope([n, counts], name): + total_permutations = math_ops.lgamma(n + 1) + counts_factorial = math_ops.lgamma(counts + 1) + redundant_permutations = math_ops.reduce_sum(counts_factorial, + reduction_indices=[-1]) + return total_permutations - redundant_permutations diff --git a/tensorflow/contrib/distributions/python/ops/exponential.py b/tensorflow/contrib/distributions/python/ops/exponential.py index c49b3eeba8d..c1a7eb025ef 100644 --- a/tensorflow/contrib/distributions/python/ops/exponential.py +++ b/tensorflow/contrib/distributions/python/ops/exponential.py @@ -46,15 +46,15 @@ class Exponential(gamma.Gamma): """Construct Exponential distribution with parameter `lam`. Args: - lam: `float` or `double` tensor, the rate of the distribution(s). + lam: Floating point tensor, the rate of the distribution(s). `lam` must contain only positive values. validate_args: Whether to assert that `lam > 0`, and that `x > 0` in the - methods `prob(x)` and `log_prob(x)`. If `validate_args` is False + methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False` and the inputs are invalid, correct behavior is not guaranteed. - allow_nan_stats: Boolean, default False. If False, raise an exception if - a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If True, batch members with valid parameters leading to undefined - statistics will return NaN for this statistic. + allow_nan_stats: Boolean, default `False`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member. If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. name: The name to prepend to all ops created by this distribution. """ # Even though all statistics of are defined for valid inputs, this is not @@ -95,8 +95,7 @@ class Exponential(gamma.Gamma): broadcast_shape = self._lam.get_shape() with ops.op_scope([self.lam, n], name, "ExponentialSample"): n = ops.convert_to_tensor(n, name="n") - shape = array_ops.concat( - 0, [array_ops.pack([n]), array_ops.shape(self._lam)]) + shape = array_ops.concat(0, ([n], array_ops.shape(self._lam))) # Sample uniformly-at-random from the open-interval (0, 1).
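# minval below is np.nextafter(0., 1.), the smallest representable float
# above 0; excluding 0 from the open interval keeps downstream transforms of
# the form -log(sampled) finite for every draw.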
sampled = random_ops.random_uniform( shape, minval=np.nextafter( diff --git a/tensorflow/contrib/distributions/python/ops/gamma.py b/tensorflow/contrib/distributions/python/ops/gamma.py index 1f733ceda16..6bd93877613 100644 --- a/tensorflow/contrib/distributions/python/ops/gamma.py +++ b/tensorflow/contrib/distributions/python/ops/gamma.py @@ -69,19 +69,19 @@ class Gamma(distribution.Distribution): broadcasting (e.g. `alpha + beta` is a valid operation). Args: - alpha: `float` or `double` tensor, the shape params of the + alpha: Floating point tensor, the shape params of the distribution(s). alpha must contain only positive values. - beta: `float` or `double` tensor, the inverse scale params of the + beta: Floating point tensor, the inverse scale params of the distribution(s). beta must contain only positive values. validate_args: Whether to assert that `a > 0, b > 0`, and that `x > 0` in - the methods `prob(x)` and `log_prob(x)`. If `validate_args` is False + the methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False` and the inputs are invalid, correct behavior is not guaranteed. - allow_nan_stats: Boolean, default False. If False, raise an exception if - a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If True, batch members with valid parameters leading to undefined - statistics will return NaN for this statistic. + allow_nan_stats: Boolean, default `False`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member. If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. name: The name to prepend to all ops created by this distribution. Raises: @@ -213,9 +213,12 @@ class Gamma(distribution.Distribution): nan = np.nan * self._ones() return math_ops.select(alpha_ge_1, mode_if_defined, nan) else: - one = ops.convert_to_tensor(1.0, dtype=self.dtype) + one = constant_op.constant(1.0, dtype=self.dtype) return control_flow_ops.with_dependencies( - [check_ops.assert_less(one, alpha)], mode_if_defined) + [check_ops.assert_less( + one, alpha, + message="mode not defined for components of alpha <= 1" + )], mode_if_defined) def variance(self, name="variance"): """Variance of each batch member.""" diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index a23f6df5717..d78e82a7524 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -69,18 +69,18 @@ class InverseGamma(distribution.Distribution): broadcasting (e.g. `alpha + beta` is a valid operation). Args: - alpha: `float` or `double` tensor, the shape params of the + alpha: Floating point tensor, the shape params of the distribution(s). alpha must contain only positive values. - beta: `float` or `double` tensor, the scale params of the distribution(s). + beta: Floating point tensor, the scale params of the distribution(s). beta must contain only positive values. validate_args: Whether to assert that `a > 0, b > 0`, and that `x > 0` in - the methods `prob(x)` and `log_prob(x)`. If `validate_args` is False + the methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False` and the inputs are invalid, correct behavior is not guaranteed. - allow_nan_stats: Boolean, default False. If False, raise an exception if - a statistic (e.g. mean/mode/etc...) is undefined for any batch member. 
- If True, batch members with valid parameters leading to undefined - statistics will return NaN for this statistic. + allow_nan_stats: Boolean, default `False`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member. If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. name: The name to prepend to all ops created by this distribution. Raises: @@ -206,9 +206,12 @@ class InverseGamma(distribution.Distribution): nan = np.nan * self._ones() return math_ops.select(alpha_gt_1, mean_if_defined, nan) else: - one = ops.convert_to_tensor(1.0, dtype=self.dtype) + one = constant_op.constant(1.0, dtype=self.dtype) return control_flow_ops.with_dependencies( - [check_ops.assert_less(one, alpha)], mean_if_defined) + [check_ops.assert_less( + one, alpha, + message="mean not defined for components of alpha <= 1")], + mean_if_defined) def mode(self, name="mode"): """Mode of each batch member. @@ -250,9 +253,12 @@ class InverseGamma(distribution.Distribution): nan = np.nan * self._ones() return math_ops.select(alpha_gt_2, var_if_defined, nan) else: - two = ops.convert_to_tensor(2.0, dtype=self.dtype) + two = constant_op.constant(2.0, dtype=self.dtype) return control_flow_ops.with_dependencies( - [check_ops.assert_less(two, alpha)], var_if_defined) + [check_ops.assert_less( + two, alpha, + message="variance not defined for components of alpha <= 2")], + var_if_defined) def log_prob(self, x, name="log_prob"): """Log prob of observations in `x` under these InverseGamma distribution(s). diff --git a/tensorflow/contrib/distributions/python/ops/kullback_leibler.py b/tensorflow/contrib/distributions/python/ops/kullback_leibler.py index c134ca2cbfd..c1e0b2d2398 100644 --- a/tensorflow/contrib/distributions/python/ops/kullback_leibler.py +++ b/tensorflow/contrib/distributions/python/ops/kullback_leibler.py @@ -34,9 +34,9 @@ def kl(dist_a, dist_b, allow_nan=False, name=None): Args: dist_a: instance of distributions.Distribution. dist_b: instance of distributions.Distribution. - allow_nan: If False (default), a runtime error is raised + allow_nan: If `False` (default), a runtime error is raised if the KL returns NaN values for any batch entry of the given - distributions. If True, the KL may return a NaN for the given entry. + distributions. If `True`, the KL may return a NaN for the given entry. name: (optional) Name scope to use for created operations. Returns: diff --git a/tensorflow/contrib/distributions/python/ops/laplace.py b/tensorflow/contrib/distributions/python/ops/laplace.py index ee6aa81c0f4..a03a80d4ece 100644 --- a/tensorflow/contrib/distributions/python/ops/laplace.py +++ b/tensorflow/contrib/distributions/python/ops/laplace.py @@ -60,17 +60,17 @@ class Laplace(distribution.Distribution): broadcasting (e.g., `loc / scale` is a valid operation). Args: - loc: `float` or `double` tensor which characterizes the location (center) + loc: Floating point tensor which characterizes the location (center) of the distribution. - scale: `float` or `double`, positive-valued tensor which characterzes the - spread of the distribution. + scale: Positive floating point tensor which characterizes the spread of + the distribution. validate_args: Whether to validate input with asserts. If `validate_args` is `False`, and the inputs are invalid, correct behavior is not guaranteed. - allow_nan_stats: Boolean, default False. If False, raise an exception if - a statistic (e.g. mean/mode/etc...) 
is undefined for any batch member. - If True, batch members with valid parameters leading to undefined - statistics will return NaN for this statistic. + allow_nan_stats: Boolean, default `False`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member. If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. name: The name to give Ops created by the initializer. Raises: @@ -294,8 +294,7 @@ class Laplace(distribution.Distribution): with ops.op_scope([self._loc, self._scale, n], name): n = ops.convert_to_tensor(n) n_val = tensor_util.constant_value(n) - shape = array_ops.concat( - 0, [array_ops.pack([n]), self.batch_shape()]) + shape = array_ops.concat(0, ([n], self.batch_shape())) # Sample uniformly-at-random from the open-interval (-1, 1). uniform_samples = random_ops.random_uniform( shape=shape, diff --git a/tensorflow/contrib/distributions/python/ops/multinomial.py b/tensorflow/contrib/distributions/python/ops/multinomial.py new file mode 100644 index 00000000000..477dd06673e --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/multinomial.py @@ -0,0 +1,343 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Multinomial distribution class.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=line-too-long + +from tensorflow.contrib.distributions.python.ops import distribution +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops + +# pylint: enable=line-too-long + + +class Multinomial(distribution.Distribution): + """Multinomial distribution. + + This distribution is parameterized by a vector `p` of probability + parameters for `k` classes and `n`, the total number of draws. + + #### Mathematical details + + The Multinomial is a distribution over k-class count data, meaning that + for each k-tuple of non-negative integer `counts = [n_1,...,n_k]`, the + distribution assigns a probability to drawing those counts. The distribution + has hyperparameters `p = (p_1,...,p_k)`, and probability mass + function (pmf): + + ```pmf(counts) = n! / (n_1!...n_k!) * (p_1)^n_1*(p_2)^n_2*...(p_k)^n_k``` + + where `n = sum_j n_j` and `n!` is `n` factorial. + + #### Examples + + Create a 3-class distribution, with the 3rd class most likely to be drawn, + using logits. + + ```python + logits = [-50., -43, 0] + dist = Multinomial(n=4., logits=logits) + ``` + + Create a 3-class distribution, with the 3rd class most likely to be drawn.
+ + ```python + p = [.2, .3, .5] + dist = Multinomial(n=4., p=p) + ``` + + The distribution functions can be evaluated on counts. + + ```python + # counts same shape as p. + counts = [1., 0, 3] + dist.prob(counts) # Shape [] + + # p will be broadcast to [[.2, .3, .5], [.2, .3, .5]] to match counts. + counts = [[1., 2, 1], [2, 2, 0]] + dist.prob(counts) # Shape [2] + + # p will be broadcast to shape [5, 7, 3] to match counts. + counts = [[...]] # Shape [5, 7, 3] + dist.prob(counts) # Shape [5, 7] + ``` + + Create a 2-batch of 3-class distributions. + + ```python + p = [[.1, .2, .7], [.3, .3, .4]] # Shape [2, 3] + dist = Multinomial(n=[4., 5], p=p) + + counts = [[2., 1, 1], [3, 1, 1]] + dist.prob(counts) # Shape [2] + ``` + """ + + def __init__(self, + n, + logits=None, + p=None, + validate_args=True, + allow_nan_stats=False, + name="Multinomial"): + """Initialize a batch of Multinomial distributions. + + Args: + n: Non-negative floating point tensor with shape broadcastable to + `[N1,..., Nm]` with `m >= 0`. Defines this as a batch of + `N1 x ... x Nm` different Multinomial distributions. Its components + should be equal to integer values. + logits: Floating point tensor representing the log-odds of a + positive event with shape broadcastable to `[N1,..., Nm, k]`, `m >= 0`, + and the same dtype as `n`. Defines this as a batch of `N1 x ... x Nm` + different `k` class Multinomial distributions. + p: Positive floating point tensor with shape broadcastable to + `[N1,..., Nm, k]`, `m >= 0`, and same dtype as `n`. Defines this as + a batch of `N1 x ... x Nm` different `k` class Multinomial + distributions. `p`'s components along the last dimension should + sum to 1. + validate_args: Whether to assert valid values for parameters `n` and `p`, + and `x` in `prob` and `log_prob`. If `False`, correct behavior is not + guaranteed. + allow_nan_stats: Boolean, default `False`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member. If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. + name: The name to prefix Ops created by this distribution class. + + Examples: + + ```python + # Define 1-batch of 2-class multinomial distribution, + # also known as a Binomial distribution. + dist = Multinomial(n=2., p=[.1, .9]) + + # Define a 2-batch of 3-class distributions. + dist = Multinomial(n=[4., 5], p=[[.1, .3, .6], [.4, .05, .55]]) + ``` + + """ + + self._logits, self._p = distribution_util.get_logits_and_prob( + name=name, logits=logits, p=p, validate_args=validate_args, + multidimensional=True) + with ops.op_scope([n, self._p], name): + with ops.control_dependencies([ + check_ops.assert_non_negative( + n, message="n has negative components."), + distribution_util.assert_integer_form( + n, message="n has non-integer components." + )] if validate_args else []): + self._n = array_ops.identity(n, name="convert_n") + self._name = name + + self._validate_args = validate_args + self._allow_nan_stats = allow_nan_stats + + self._mean = array_ops.expand_dims(n, -1) * self._p + # Only used for inferring shape.
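+    # reduce_sum collapses the event dimension of self._mean, leaving a
+    # tensor whose shape is the broadcast of n's shape against p's batch
+    # shape; only its shape is ever inspected, never its values.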
+ self._broadcast_shape = math_ops.reduce_sum(self._mean, + reduction_indices=[-1], + keep_dims=False) + + self._get_batch_shape = self._broadcast_shape.get_shape() + self._get_event_shape = ( + self._mean.get_shape().with_rank_at_least(1)[-1:]) + + @property + def n(self): + """Number of trials.""" + return self._n + + @property + def p(self): + """Event probabilities.""" + return self._p + + @property + def logits(self): + """Log-odds.""" + return self._logits + + @property + def name(self): + """Name to prepend to all ops.""" + return self._name + + @property + def dtype(self): + """dtype of samples from this distribution.""" + return self._p.dtype + + @property + def validate_args(self): + """Boolean describing behavior on invalid input.""" + return self._validate_args + + @property + def allow_nan_stats(self): + """Boolean describing behavior when a stat is undefined for batch member.""" + return self._allow_nan_stats + + def batch_shape(self, name="batch_shape"): + """Batch dimensions of this instance as a 1-D int32 `Tensor`. + + The product of the dimensions of the `batch_shape` is the number of + independent distributions of this kind the instance represents. + + Args: + name: name to give to the op + + Returns: + `Tensor` `batch_shape` + """ + with ops.name_scope(self.name): + with ops.op_scope([self._broadcast_shape], name): + return array_ops.shape(self._broadcast_shape) + + def get_batch_shape(self): + """`TensorShape` available at graph construction time. + + Same meaning as `batch_shape`. May be only partially defined. + + Returns: + batch shape + """ + return self._get_batch_shape + + def event_shape(self, name="event_shape"): + """Shape of a sample from a single distribution as a 1-D int32 `Tensor`. + + Args: + name: name to give to the op + + Returns: + `Tensor` `event_shape` + """ + with ops.name_scope(self.name): + with ops.op_scope([self._mean], name): + return array_ops.gather(array_ops.shape(self._mean), + [array_ops.rank(self._mean) - 1]) + + def get_event_shape(self): + """`TensorShape` available at graph construction time. + + Same meaning as `event_shape`. May be only partially defined. + + Returns: + event shape + """ + return self._get_event_shape + + def mean(self, name="mean"): + """Mean of the distribution.""" + with ops.name_scope(self.name): + return array_ops.identity(self._mean, name=name) + + def variance(self, name="variance"): + """Variance of the distribution.""" + with ops.name_scope(self.name): + with ops.op_scope([self._n, self._p, self._mean], name): + p = array_ops.expand_dims( + self._p * array_ops.expand_dims( + array_ops.ones_like(self._n), -1), -1) + variance = -math_ops.batch_matmul( + array_ops.expand_dims(self._mean, -1), p, adj_y=True) + variance += array_ops.batch_matrix_diag(self._mean) + return variance + + def log_prob(self, counts, name="log_prob"): + """`Log(P[counts])`, computed for every batch member. + + For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability + that after sampling `n` draws from this Multinomial distribution, the + number of draws falling in class `j` is `n_j`. Note that different + sequences of draws can result in the same counts, thus the probability + includes a combinatorial coefficient. + + Args: + counts: Non-negative tensor with dtype `dtype` and whose shape can + be broadcast with `self.p` and `self.n`. For fixed leading dimensions, + the last dimension represents counts for the corresponding Multinomial + distribution in `self.p`. 
`counts` is only legal if it sums up to `n` + and its components are equal to integer values. + name: Name to give this Op, defaults to "log_prob". + + Returns: + Log probabilities for each record, shape `[N1,...,Nm]`. + """ + n = self._n + p = self._p + with ops.name_scope(self.name): + with ops.op_scope([n, p, counts], name): + counts = self._check_counts(counts) + + unnorm_log_prob = math_ops.reduce_sum(counts * math_ops.log(self._p), + reduction_indices=[-1]) + log_prob = unnorm_log_prob + distribution_util.log_combinations( + n, counts) + return log_prob + + def prob(self, counts, name="prob"): + """`P[counts]`, computed for every batch member. + + For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability + that after sampling `n` draws from this Multinomial distribution, the + number of draws falling in class `j` is `n_j`. Note that different + sequences of draws can result in the same counts, thus the probability + includes a combinatorial coefficient. + + Args: + counts: Non-negative tensor with dtype `dtype` and whose shape can + be broadcast with `self.p` and `self.n`. For fixed leading dimensions, + the last dimension represents counts for the corresponding Multinomial + distribution in `self.p`. `counts` is only legal if it sums up to `n` + and its components are equal to integer values. + name: Name to give this Op, defaults to "prob". + + Returns: + Probabilities for each record, shape `[N1,...,Nm]`. + """ + return super(Multinomial, self).prob(counts, name=name) + + @property + def is_continuous(self): + return False + + @property + def is_reparameterized(self): + return False + + def _check_counts(self, counts): + """Check counts for proper shape, values, then return tensor version.""" + counts = ops.convert_to_tensor(counts, name="counts_before_deps") + if not self.validate_args: + return counts + candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1]) + + return control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + counts, message="counts has negative components."), + check_ops.assert_equal( + self._n, candidate_n, message="counts do not sum to n."), + distribution_util.assert_integer_form( + counts, message="counts has non-integer components.")], counts) diff --git a/tensorflow/contrib/distributions/python/ops/mvn.py b/tensorflow/contrib/distributions/python/ops/mvn.py index dafddc0faac..8936594dfac 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn.py +++ b/tensorflow/contrib/distributions/python/ops/mvn.py @@ -105,9 +105,9 @@ class MultivariateNormalOperatorPD(distribution.Distribution): which determines the covariance. Args: - mu: `float` or `double` tensor with shape `[N1,...,Nb, k]`, `b >= 0`. - cov: `float` or `double` instance of `OperatorPDBase` with same `dtype` - as `mu` and shape `[N1,...,Nb, k, k]`. + mu: Floating point tensor with shape `[N1,...,Nb, k]`, `b >= 0`. + cov: Instance of `OperatorPDBase` with same `dtype` as `mu` and shape + `[N1,...,Nb, k, k]`. validate_args: Whether to validate input with asserts. If `validate_args` is `False`, and the inputs are invalid, correct behavior is not guaranteed. @@ -466,7 +466,7 @@ class MultivariateNormalDiag(MultivariateNormalOperatorPD): The mean of `X_i` is `mu[i]`, and the standard deviation is `diag_stdev[i]`. Args: - mu: Rank `N + 1` `float` or `double` tensor with shape `[N1,...,Nb, k]`, + mu: Rank `N + 1` floating point tensor with shape `[N1,...,Nb, k]`, `b >= 0`.
diag_stdev: Rank `N + 1` `Tensor` with same `dtype` and shape as `mu`, representing the standard deviations. Must be positive. @@ -581,13 +581,13 @@ class MultivariateNormalDiagPlusVDVT(MultivariateNormalOperatorPD): ``` Args: - mu: Rank `n + 1` `float` or `double` tensor with shape `[N1,...,Nn, k]`, + mu: Rank `n + 1` floating point tensor with shape `[N1,...,Nn, k]`, `n >= 0`. The means. - diag_large: Optional rank `n + 1` `float` or `double` tensor, shape + diag_large: Optional rank `n + 1` floating point tensor, shape `[N1,...,Nn, k]` `n >= 0`. Defines the diagonal matrix `M`. - v: Rank `n + 1` `float` or `double` tensor, shape `[N1,...,Nn, k, r]` + v: Rank `n + 1` floating point tensor, shape `[N1,...,Nn, k, r]` `n >= 0`. Defines the matrix `V`. - diag_small: Rank `n + 1` `float` or `double` tensor, shape + diag_small: Rank `n + 1` floating point tensor, shape `[N1,...,Nn, k]` `n >= 0`. Defines the diagonal matrix `D`. Default is `None`, which means `D` will be the identity matrix. validate_args: Whether to validate input with asserts. If `validate_args` @@ -670,7 +670,7 @@ class MultivariateNormalCholesky(MultivariateNormalOperatorPD): factors, such that the covariance of each batch member is `chol chol^T`. Args: - mu: `(N+1)-D` `float` or `double` tensor with shape `[N1,...,Nb, k]`, + mu: `(N+1)-D` floating point tensor with shape `[N1,...,Nb, k]`, `b >= 0`. chol: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape `[N1,...,Nb, k, k]`. The upper triangular part is ignored (treated as @@ -750,7 +750,7 @@ class MultivariateNormalFull(MultivariateNormalOperatorPD): User must provide means `mu` and `sigma`, the mean and covariance. Args: - mu: `(N+1)-D` `float` or `double` tensor with shape `[N1,...,Nb, k]`, + mu: `(N+1)-D` floating point tensor with shape `[N1,...,Nb, k]`, `b >= 0`. sigma: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape `[N1,...,Nb, k, k]`. Each batch member must be positive definite. diff --git a/tensorflow/contrib/distributions/python/ops/normal.py b/tensorflow/contrib/distributions/python/ops/normal.py index dff8c7fdbbe..182afa31f7f 100644 --- a/tensorflow/contrib/distributions/python/ops/normal.py +++ b/tensorflow/contrib/distributions/python/ops/normal.py @@ -92,15 +92,15 @@ class Normal(distribution.Distribution): broadcasting (e.g. `mu + sigma` is a valid operation). Args: - mu: `float` or `double` tensor, the means of the distribution(s). - sigma: `float` or `double` tensor, the stddevs of the distribution(s). + mu: Floating point tensor, the means of the distribution(s). + sigma: Floating point tensor, the stddevs of the distribution(s). sigma must contain only positive values. validate_args: Whether to assert that `sigma > 0`. If `validate_args` is - False, correct output is not guaranteed when input is invalid. - allow_nan_stats: Boolean, default False. If False, raise an exception if - a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If True, batch members with valid parameters leading to undefined - statistics will return NaN for this statistic. + `False`, correct output is not guaranteed when input is invalid. + allow_nan_stats: Boolean, default `False`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member. If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. name: The name to give Ops created by the initializer. 
Raises: @@ -321,8 +321,7 @@ class Normal(distribution.Distribution): with ops.op_scope([self._mu, self._sigma, n], name): broadcast_shape = (self._mu + self._sigma).get_shape() n = ops.convert_to_tensor(n) - shape = array_ops.concat( - 0, [array_ops.pack([n]), array_ops.shape(self.mean())]) + shape = array_ops.concat(0, ([n], array_ops.shape(self.mean()))) sampled = random_ops.random_normal( shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed) diff --git a/tensorflow/contrib/distributions/python/ops/student_t.py b/tensorflow/contrib/distributions/python/ops/student_t.py index e5fa624ddc4..8e43c95b6db 100644 --- a/tensorflow/contrib/distributions/python/ops/student_t.py +++ b/tensorflow/contrib/distributions/python/ops/student_t.py @@ -82,6 +82,7 @@ class StudentT(distribution.Distribution): # returning a length 2 tensor. dist.pdf(3.0) ``` + """ def __init__(self, @@ -99,19 +100,19 @@ class StudentT(distribution.Distribution): broadcasting (e.g. `df + mu + sigma` is a valid operation). Args: - df: `float` or `double` tensor, the degrees of freedom of the + df: Floating point tensor, the degrees of freedom of the distribution(s). `df` must contain only positive values. - mu: `float` or `double` tensor, the means of the distribution(s). - sigma: `float` or `double` tensor, the scaling factor for the + mu: Floating point tensor, the means of the distribution(s). + sigma: Floating point tensor, the scaling factor for the distribution(s). `sigma` must contain only positive values. Note that `sigma` is not the standard deviation of this distribution. validate_args: Whether to assert that `df > 0, sigma > 0`. If - `validate_args` is False and inputs are invalid, correct behavior is not - guaranteed. - allow_nan_stats: Boolean, default False. If False, raise an exception if - a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If True, batch members with valid parameters leading to undefined - statistics will return NaN for this statistic. + `validate_args` is `False` and inputs are invalid, correct behavior is + not guaranteed. + allow_nan_stats: Boolean, default `False`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member. If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. name: The name to give Ops created by the initializer. 
Raises: @@ -185,9 +186,12 @@ class StudentT(distribution.Distribution): nan = np.nan + self._zeros() return math_ops.select(df_gt_1, result_if_defined, nan) else: - one = ops.convert_to_tensor(1.0, dtype=self.dtype) + one = constant_op.constant(1.0, dtype=self.dtype) return control_flow_ops.with_dependencies( - [check_ops.assert_less(one, self._df)], result_if_defined) + [check_ops.assert_less( + one, self._df, + message="mean not defined for components of df <= 1" + )], result_if_defined) def mode(self, name="mode"): with ops.name_scope(self.name): @@ -232,9 +236,12 @@ class StudentT(distribution.Distribution): result_where_defined, self._zeros() + np.nan) else: - one = ops.convert_to_tensor(1.0, self.dtype) + one = constant_op.constant(1.0, dtype=self.dtype) return control_flow_ops.with_dependencies( - [check_ops.assert_less(one, self._df)], result_where_defined) + [check_ops.assert_less( + one, self._df, + message="variance not defined for components of df <= 1" + )], result_where_defined) def std(self, name="std"): with ops.name_scope(self.name): @@ -348,8 +355,7 @@ class StudentT(distribution.Distribution): # Let X = R*cos(theta), and let Y = R*sin(theta). # Then X ~ t_df and Y ~ t_df. # The variates X and Y are not independent. - shape = array_ops.concat(0, [array_ops.pack([2, n]), - self.batch_shape()]) + shape = array_ops.concat(0, ([2, n], self.batch_shape())) uniform = random_ops.random_uniform(shape=shape, dtype=self.dtype, seed=seed) diff --git a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py index 185741b2176..82971301560 100644 --- a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py @@ -57,6 +57,7 @@ class TransformedDistribution(distribution.Distribution): name="LogitNormalTransformedDistribution" ) ``` + """ def __init__(self, diff --git a/tensorflow/contrib/distributions/python/ops/uniform.py b/tensorflow/contrib/distributions/python/ops/uniform.py index eb196a3ea91..09437d36d16 100644 --- a/tensorflow/contrib/distributions/python/ops/uniform.py +++ b/tensorflow/contrib/distributions/python/ops/uniform.py @@ -67,14 +67,14 @@ class Uniform(distribution.Distribution): ``` Args: - a: `float` or `double` tensor, the minimum endpoint. - b: `float` or `double` tensor, the maximum endpoint. Must be > `a`. - validate_args: Whether to assert that `a > b`. If `validate_args` is False - and inputs are invalid, correct behavior is not guaranteed. - allow_nan_stats: Boolean, default False. If False, raise an exception if - a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If True, batch members with valid parameters leading to undefined - statistics will return NaN for this statistic. + a: Floating point tensor, the minimum endpoint. + b: Floating point tensor, the maximum endpoint. Must be > `a`. + validate_args: Whether to assert that `a < b`. If `validate_args` is + `False` and inputs are invalid, correct behavior is not guaranteed. + allow_nan_stats: Boolean, default `False`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member. If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. name: The name to prefix Ops created by this distribution class.
Raises: @@ -83,8 +83,9 @@ class Uniform(distribution.Distribution): self._allow_nan_stats = allow_nan_stats self._validate_args = validate_args with ops.op_scope([a, b], name): - with ops.control_dependencies([check_ops.assert_less(a, b)] if - validate_args else []): + with ops.control_dependencies([check_ops.assert_less( + a, b, message="uniform not defined when a > b.")] if validate_args + else []): a = array_ops.identity(a, name="a") b = array_ops.identity(b, name="b") @@ -228,7 +229,7 @@ class Uniform(distribution.Distribution): n = ops.convert_to_tensor(n, name="n") n_val = tensor_util.constant_value(n) - shape = array_ops.concat(0, [array_ops.pack([n]), self.batch_shape()]) + shape = array_ops.concat(0, ([n], self.batch_shape())) samples = random_ops.random_uniform(shape=shape, dtype=self.dtype, seed=seed)
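Taken together, these hunks standardize parameter validation and its error messages across the distributions. A quick sketch of the resulting behavior for the new `Multinomial` (illustrative only; assumes the `tf.contrib.distributions` import path added above and graph-mode `Session` execution):

```python
import tensorflow as tf

# A 3-class Multinomial over n=4 draws; validate_args defaults to True.
dist = tf.contrib.distributions.Multinomial(n=4., p=[.2, .3, .5])

good_counts = [2., 1, 1]  # non-negative, integer-valued, sums to n = 4
bad_counts = [2., 2, 1]   # sums to 5, so the counts assert should fire

with tf.Session() as sess:
  print(sess.run(dist.prob(good_counts)))  # a scalar pmf value
  # Expected to raise InvalidArgumentError with the message added above:
  # "counts do not sum to n."
  sess.run(dist.prob(bad_counts))
```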