Fix a bug that an erroneous control edge can be introduced when loops are nested in control dependency context.
PiperOrigin-RevId: 157616919
This commit is contained in:
parent
2de94bbb89
commit
964d1a5090
@ -1085,6 +1085,27 @@ class ControlFlowTest(test.TestCase):
|
||||
(constant_op.constant(5),))
|
||||
self.assertEqual(0, sess.run(loop))
|
||||
|
||||
def testWhileCondWithControl_1(self):
|
||||
with self.test_session():
|
||||
v = variable_scope.get_variable(
|
||||
"v", [], initializer=init_ops.constant_initializer(2))
|
||||
i0 = constant_op.constant(0)
|
||||
with ops.control_dependencies([i0]):
|
||||
def loop_condition(i):
|
||||
return i < 4
|
||||
|
||||
def loop_body(i):
|
||||
some_cond = control_flow_ops.cond(
|
||||
constant_op.constant(True),
|
||||
lambda: state_ops.assign(v, math_ops.square(v)),
|
||||
lambda: v)
|
||||
with ops.control_dependencies([some_cond]):
|
||||
return i + 1
|
||||
r = control_flow_ops.while_loop(loop_condition, loop_body, (i0,))
|
||||
variables.global_variables_initializer().run()
|
||||
self.assertEqual(4, r.eval())
|
||||
self.assertAllClose(65536.0, v.eval())
|
||||
|
||||
def testWhileCondExitControl(self):
|
||||
with self.test_session():
|
||||
v = variables.Variable(1)
|
||||
|
@ -1634,6 +1634,8 @@ class CondContext(ControlFlowContext):
|
||||
# pylint: disable=protected-access
|
||||
op._update_input(index, real_x)
|
||||
# pylint: enable=protected-access
|
||||
# Remove any external control dependency on this op.
|
||||
self._RemoveExternalControlEdges(op)
|
||||
for x in op.outputs:
|
||||
self._values.add(x.name)
|
||||
# pylint: disable=protected-access
|
||||
|
Loading…
Reference in New Issue
Block a user