Adding test for tf.cond with a mutable boolean pred.

PiperOrigin-RevId: 167844952
This commit is contained in:
A. Unique TensorFlower 2017-09-07 04:45:15 -07:00 committed by TensorFlower Gardener
parent 33acf5ba4f
commit 05c9966eda

View File

@ -362,6 +362,20 @@ class CondTest(TensorFlowTestCase):
fn2=lambda: math_ops.add(y, 23))
self.assertEquals(z.eval(), 24)
def testCondModifyBoolPred(self):
# This test in particular used to fail only when running in GPU, hence
# use_gpu=True.
with self.test_session(use_gpu=True) as sess:
bool_var = variable_scope.get_variable("bool_var", dtype=dtypes.bool,
initializer=True)
cond_on_bool_var = control_flow_ops.cond(
pred=bool_var,
true_fn=lambda: state_ops.assign(bool_var, False),
false_fn=lambda: True)
sess.run(bool_var.initializer)
self.assertEquals(sess.run(cond_on_bool_var), False)
self.assertEquals(sess.run(cond_on_bool_var), True)
def testCondMissingArg1(self):
with self.test_session():
x = constant_op.constant(1)