diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index d4ab4ca7aa4..77982654bd3 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -1085,6 +1085,27 @@ class ControlFlowTest(test.TestCase): (constant_op.constant(5),)) self.assertEqual(0, sess.run(loop)) + def testWhileCondWithControl_1(self): + with self.test_session(): + v = variable_scope.get_variable( + "v", [], initializer=init_ops.constant_initializer(2)) + i0 = constant_op.constant(0) + with ops.control_dependencies([i0]): + def loop_condition(i): + return i < 4 + + def loop_body(i): + some_cond = control_flow_ops.cond( + constant_op.constant(True), + lambda: state_ops.assign(v, math_ops.square(v)), + lambda: v) + with ops.control_dependencies([some_cond]): + return i + 1 + r = control_flow_ops.while_loop(loop_condition, loop_body, (i0,)) + variables.global_variables_initializer().run() + self.assertEqual(4, r.eval()) + self.assertAllClose(65536.0, v.eval()) + def testWhileCondExitControl(self): with self.test_session(): v = variables.Variable(1) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 96ace6e79b4..d0f7acbd02e 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -1634,6 +1634,8 @@ class CondContext(ControlFlowContext): # pylint: disable=protected-access op._update_input(index, real_x) # pylint: enable=protected-access + # Remove any external control dependency on this op. + self._RemoveExternalControlEdges(op) for x in op.outputs: self._values.add(x.name) # pylint: disable=protected-access