Make sure there is a control dependency from the v1 while_loop pivot to the Const nodes created inside functional ops.
This is needed to prevent frame mismatch errors where there are Const nodes inside tf.function in v1 while_loop and inlining is turned on. PiperOrigin-RevId: 302974817 Change-Id: I740de2ce104a42f66a0fd4be284afd9fd34c30cf
This commit is contained in:
parent
5d169b40e9
commit
a47cd610e3
@ -3011,6 +3011,25 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
self.assertAllEqual(f(), 4. * 2.**3) # 4 * x_init ^ 3
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testTfFunctionInV1WhileLoop(self):
|
||||
|
||||
# This test specifically tests that creating a Const node inside a
|
||||
# tf.function inside a v1 while_loop while inlining is turned on works.
|
||||
config = opt_cfg()
|
||||
assert config.graph_options.optimizer_options.do_function_inlining
|
||||
with session.Session(config=config):
|
||||
|
||||
@def_function.function
|
||||
def loop_body(i):
|
||||
# Here we create the const.
|
||||
return i + 1.
|
||||
|
||||
loop_cond = lambda i: True
|
||||
x = control_flow_ops.while_loop(
|
||||
loop_cond, loop_body, [0.], maximum_iterations=5)
|
||||
self.assertAllEqual(x, 5.)
|
||||
|
||||
def _testNestedWhileCondWhileGrad(self, use_gpu):
|
||||
|
||||
with self.cached_session(use_gpu=use_gpu):
|
||||
|
@ -1724,6 +1724,10 @@ class WhileContext(ControlFlowContext):
|
||||
We move any external control dependencies of the op to the loop pivot, to
|
||||
ensure they get executed.
|
||||
"""
|
||||
# This is needed to prevent frame mismatch errors where there are Const
|
||||
# nodes inside tf.function in v1 while_loop and inlining is turned on.
|
||||
if op.type in ["PartitionedCall", "StatefulPartitionedCall"]:
|
||||
op._add_control_input(self.GetControlPivot().op) # pylint: disable=protected-access
|
||||
if not op.inputs:
|
||||
# Remove any external control dependency on this op
|
||||
control_inputs, external_inputs = self._RemoveExternalControlEdges(op)
|
||||
|
Loading…
Reference in New Issue
Block a user