Fix a bug that an erroneous control edge can be introduced when loops are nested in control dependency context.

PiperOrigin-RevId: 157616919
This commit is contained in:
Yuan Yu 2017-05-31 12:09:27 -07:00 committed by TensorFlower Gardener
parent 2de94bbb89
commit 964d1a5090
2 changed files with 23 additions and 0 deletions

View File

@ -1085,6 +1085,27 @@ class ControlFlowTest(test.TestCase):
(constant_op.constant(5),))
self.assertEqual(0, sess.run(loop))
def testWhileCondWithControl_1(self):
with self.test_session():
v = variable_scope.get_variable(
"v", [], initializer=init_ops.constant_initializer(2))
i0 = constant_op.constant(0)
with ops.control_dependencies([i0]):
def loop_condition(i):
return i < 4
def loop_body(i):
some_cond = control_flow_ops.cond(
constant_op.constant(True),
lambda: state_ops.assign(v, math_ops.square(v)),
lambda: v)
with ops.control_dependencies([some_cond]):
return i + 1
r = control_flow_ops.while_loop(loop_condition, loop_body, (i0,))
variables.global_variables_initializer().run()
self.assertEqual(4, r.eval())
self.assertAllClose(65536.0, v.eval())
def testWhileCondExitControl(self):
with self.test_session():
v = variables.Variable(1)

View File

@ -1634,6 +1634,8 @@ class CondContext(ControlFlowContext):
# pylint: disable=protected-access
op._update_input(index, real_x)
# pylint: enable=protected-access
# Remove any external control dependency on this op.
self._RemoveExternalControlEdges(op)
for x in op.outputs:
self._values.add(x.name)
# pylint: disable=protected-access