Bugfix: Make tf.contrib.distributions.Independent
tests not flaky.
PiperOrigin-RevId: 173921378
This commit is contained in:
parent
4b63f47d9f
commit
629e6d0c10
@ -56,7 +56,7 @@ class ProductDistributionTest(test.TestCase):
|
|||||||
distribution=normal_lib.Normal(loc=loc, scale=scale),
|
distribution=normal_lib.Normal(loc=loc, scale=scale),
|
||||||
reinterpreted_batch_ndims=1)
|
reinterpreted_batch_ndims=1)
|
||||||
|
|
||||||
x = ind.sample([4, 5])
|
x = ind.sample([4, 5], seed=42)
|
||||||
log_prob_x = ind.log_prob(x)
|
log_prob_x = ind.log_prob(x)
|
||||||
x_, actual_log_prob_x = sess.run([x, log_prob_x])
|
x_, actual_log_prob_x = sess.run([x, log_prob_x])
|
||||||
|
|
||||||
@ -79,7 +79,7 @@ class ProductDistributionTest(test.TestCase):
|
|||||||
scale_identity_multiplier=scale),
|
scale_identity_multiplier=scale),
|
||||||
reinterpreted_batch_ndims=1)
|
reinterpreted_batch_ndims=1)
|
||||||
|
|
||||||
x = ind.sample([4, 5])
|
x = ind.sample([4, 5], seed=42)
|
||||||
log_prob_x = ind.log_prob(x)
|
log_prob_x = ind.log_prob(x)
|
||||||
x_, actual_log_prob_x = sess.run([x, log_prob_x])
|
x_, actual_log_prob_x = sess.run([x, log_prob_x])
|
||||||
|
|
||||||
@ -141,7 +141,7 @@ class ProductDistributionTest(test.TestCase):
|
|||||||
dtypes.float32, shape=logits.shape if static_shape else None)
|
dtypes.float32, shape=logits.shape if static_shape else None)
|
||||||
ind = independent_lib.Independent(
|
ind = independent_lib.Independent(
|
||||||
distribution=bernoulli_lib.Bernoulli(logits=logits_ph))
|
distribution=bernoulli_lib.Bernoulli(logits=logits_ph))
|
||||||
x = ind.sample(sample_shape)
|
x = ind.sample(sample_shape, seed=42)
|
||||||
log_prob_x = ind.log_prob(x)
|
log_prob_x = ind.log_prob(x)
|
||||||
[
|
[
|
||||||
x_,
|
x_,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user