BUGFIX: Call n = convert_to_tensor(n).
Exponential.sample forgot to call n = convert_to_tensor(n), and tensor_util.constant_value(n) only works with n a Tensor. Change: 123248349
This commit is contained in:
parent
8074e98b20
commit
606fbb46eb
@ -105,10 +105,9 @@ class ExponentialTest(tf.test.TestCase):
|
||||
|
||||
exponential = tf.contrib.distributions.Exponential(lam=lam)
|
||||
|
||||
n_v = 100000
|
||||
n = tf.constant(n_v)
|
||||
n = 100000
|
||||
samples = exponential.sample(n, seed=138)
|
||||
self.assertEqual(samples.get_shape(), (n_v, batch_size, 2))
|
||||
self.assertEqual(samples.get_shape(), (n, batch_size, 2))
|
||||
|
||||
sample_values = samples.eval()
|
||||
|
||||
|
@ -70,6 +70,7 @@ class Exponential(gamma.Gamma):
|
||||
"""
|
||||
broadcast_shape = self._lam.get_shape()
|
||||
with ops.op_scope([self.lam, n], name, "ExponentialSample"):
|
||||
n = ops.convert_to_tensor(n, name="n")
|
||||
shape = array_ops.concat(
|
||||
0, [array_ops.pack([n]), array_ops.shape(self._lam)])
|
||||
sampled = random_ops.random_uniform(
|
||||
|
Loading…
Reference in New Issue
Block a user