From 606fbb46eb20c795eacd9bec056df062c6760792 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 May 2016 12:41:42 -0800 Subject: [PATCH] BUGFIX: Call n = convert_to_tensor(n). Exponential.sample forgot to call n = convert_to_tensor(n), and tensor_util.constant_value(n) only works with n a Tensor. Change: 123248349 --- .../distributions/python/kernel_tests/exponential_test.py | 5 ++--- tensorflow/contrib/distributions/python/ops/exponential.py | 1 + 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/exponential_test.py b/tensorflow/contrib/distributions/python/kernel_tests/exponential_test.py index 5e3fed1ed80..6fd03e90bf6 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/exponential_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/exponential_test.py @@ -105,10 +105,9 @@ class ExponentialTest(tf.test.TestCase): exponential = tf.contrib.distributions.Exponential(lam=lam) - n_v = 100000 - n = tf.constant(n_v) + n = 100000 samples = exponential.sample(n, seed=138) - self.assertEqual(samples.get_shape(), (n_v, batch_size, 2)) + self.assertEqual(samples.get_shape(), (n, batch_size, 2)) sample_values = samples.eval() diff --git a/tensorflow/contrib/distributions/python/ops/exponential.py b/tensorflow/contrib/distributions/python/ops/exponential.py index b80632fc496..4a93c210b91 100644 --- a/tensorflow/contrib/distributions/python/ops/exponential.py +++ b/tensorflow/contrib/distributions/python/ops/exponential.py @@ -70,6 +70,7 @@ class Exponential(gamma.Gamma): """ broadcast_shape = self._lam.get_shape() with ops.op_scope([self.lam, n], name, "ExponentialSample"): + n = ops.convert_to_tensor(n, name="n") shape = array_ops.concat( 0, [array_ops.pack([n]), array_ops.shape(self._lam)]) sampled = random_ops.random_uniform(