Fix a bug that an erroneous control edge can be introduced when loops are nested in control dependency context.
PiperOrigin-RevId: 157616919
This commit is contained in:
parent
2de94bbb89
commit
964d1a5090
@ -1085,6 +1085,27 @@ class ControlFlowTest(test.TestCase):
|
|||||||
(constant_op.constant(5),))
|
(constant_op.constant(5),))
|
||||||
self.assertEqual(0, sess.run(loop))
|
self.assertEqual(0, sess.run(loop))
|
||||||
|
|
||||||
|
def testWhileCondWithControl_1(self):
|
||||||
|
with self.test_session():
|
||||||
|
v = variable_scope.get_variable(
|
||||||
|
"v", [], initializer=init_ops.constant_initializer(2))
|
||||||
|
i0 = constant_op.constant(0)
|
||||||
|
with ops.control_dependencies([i0]):
|
||||||
|
def loop_condition(i):
|
||||||
|
return i < 4
|
||||||
|
|
||||||
|
def loop_body(i):
|
||||||
|
some_cond = control_flow_ops.cond(
|
||||||
|
constant_op.constant(True),
|
||||||
|
lambda: state_ops.assign(v, math_ops.square(v)),
|
||||||
|
lambda: v)
|
||||||
|
with ops.control_dependencies([some_cond]):
|
||||||
|
return i + 1
|
||||||
|
r = control_flow_ops.while_loop(loop_condition, loop_body, (i0,))
|
||||||
|
variables.global_variables_initializer().run()
|
||||||
|
self.assertEqual(4, r.eval())
|
||||||
|
self.assertAllClose(65536.0, v.eval())
|
||||||
|
|
||||||
def testWhileCondExitControl(self):
|
def testWhileCondExitControl(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
v = variables.Variable(1)
|
v = variables.Variable(1)
|
||||||
|
@ -1634,6 +1634,8 @@ class CondContext(ControlFlowContext):
|
|||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
op._update_input(index, real_x)
|
op._update_input(index, real_x)
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
# Remove any external control dependency on this op.
|
||||||
|
self._RemoveExternalControlEdges(op)
|
||||||
for x in op.outputs:
|
for x in op.outputs:
|
||||||
self._values.add(x.name)
|
self._values.add(x.name)
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
|
Loading…
Reference in New Issue
Block a user