Make sure there is a control dependency from the v1 while_loop pivot to the Const nodes created inside functional ops.

This is needed to prevent frame mismatch errors where there are Const nodes inside tf.function in v1 while_loop and inlining is turned on.

PiperOrigin-RevId: 302974817
Change-Id: I740de2ce104a42f66a0fd4be284afd9fd34c30cf
This commit is contained in:
Saurabh Saxena 2020-03-25 14:46:31 -07:00 committed by TensorFlower Gardener
parent 5d169b40e9
commit a47cd610e3
2 changed files with 23 additions and 0 deletions

View File

@ -3011,6 +3011,25 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(f(), 4. * 2.**3) # 4 * x_init ^ 3
@test_util.run_deprecated_v1
def testTfFunctionInV1WhileLoop(self):
# This test specifically tests that creating a Const node inside a
# tf.function inside a v1 while_loop while inlining is turned on works.
config = opt_cfg()
assert config.graph_options.optimizer_options.do_function_inlining
with session.Session(config=config):
@def_function.function
def loop_body(i):
# Here we create the const.
return i + 1.
loop_cond = lambda i: True
x = control_flow_ops.while_loop(
loop_cond, loop_body, [0.], maximum_iterations=5)
self.assertAllEqual(x, 5.)
def _testNestedWhileCondWhileGrad(self, use_gpu):
with self.cached_session(use_gpu=use_gpu):

View File

@ -1724,6 +1724,10 @@ class WhileContext(ControlFlowContext):
We move any external control dependencies of the op to the loop pivot, to
ensure they get executed.
"""
# This is needed to prevent frame mismatch errors where there are Const
# nodes inside tf.function in v1 while_loop and inlining is turned on.
if op.type in ["PartitionedCall", "StatefulPartitionedCall"]:
op._add_control_input(self.GetControlPivot().op) # pylint: disable=protected-access
if not op.inputs:
# Remove any external control dependency on this op
control_inputs, external_inputs = self._RemoveExternalControlEdges(op)