BUGFIX: Call n = convert_to_tensor(n).

Exponential.sample forgot to call n = convert_to_tensor(n), and
tensor_util.constant_value(n) only works with n a Tensor.
Change: 123248349
This commit is contained in:
A. Unique TensorFlower 2016-05-25 12:41:42 -08:00 committed by TensorFlower Gardener
parent 8074e98b20
commit 606fbb46eb
2 changed files with 3 additions and 3 deletions

View File

@ -105,10 +105,9 @@ class ExponentialTest(tf.test.TestCase):
exponential = tf.contrib.distributions.Exponential(lam=lam)
n_v = 100000
n = tf.constant(n_v)
n = 100000
samples = exponential.sample(n, seed=138)
self.assertEqual(samples.get_shape(), (n_v, batch_size, 2))
self.assertEqual(samples.get_shape(), (n, batch_size, 2))
sample_values = samples.eval()

View File

@ -70,6 +70,7 @@ class Exponential(gamma.Gamma):
"""
broadcast_shape = self._lam.get_shape()
with ops.op_scope([self.lam, n], name, "ExponentialSample"):
n = ops.convert_to_tensor(n, name="n")
shape = array_ops.concat(
0, [array_ops.pack([n]), array_ops.shape(self._lam)])
sampled = random_ops.random_uniform(