Adding test for tf.cond with a mutable boolean pred.
PiperOrigin-RevId: 167844952
This commit is contained in:
parent
33acf5ba4f
commit
05c9966eda
@ -362,6 +362,20 @@ class CondTest(TensorFlowTestCase):
|
||||
fn2=lambda: math_ops.add(y, 23))
|
||||
self.assertEquals(z.eval(), 24)
|
||||
|
||||
def testCondModifyBoolPred(self):
|
||||
# This test in particular used to fail only when running in GPU, hence
|
||||
# use_gpu=True.
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
bool_var = variable_scope.get_variable("bool_var", dtype=dtypes.bool,
|
||||
initializer=True)
|
||||
cond_on_bool_var = control_flow_ops.cond(
|
||||
pred=bool_var,
|
||||
true_fn=lambda: state_ops.assign(bool_var, False),
|
||||
false_fn=lambda: True)
|
||||
sess.run(bool_var.initializer)
|
||||
self.assertEquals(sess.run(cond_on_bool_var), False)
|
||||
self.assertEquals(sess.run(cond_on_bool_var), True)
|
||||
|
||||
def testCondMissingArg1(self):
|
||||
with self.test_session():
|
||||
x = constant_op.constant(1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user