From acfb05435f233a21f5f904e5aa9ea3631b1ac58e Mon Sep 17 00:00:00 2001 From: Ran Chen Date: Tue, 6 Oct 2020 10:40:35 -0700 Subject: [PATCH] Add read variable op to function control outputs After inlining, we convert control dependencies to the function call to depend on all control outputs. Although read variable ops doesn't have side effect itself, it must be executed before variable updates after the function call. Note that it's not enough to convert control dependencies to the function call to depend on both control outputs and data outputs of the function call. The read variable ops can be inputs to ops that have side effects, e.g. assert and print, which are not the function data outputs. PiperOrigin-RevId: 335671909 Change-Id: I6fa88e6ffe6997aa1f4e7bcebf4089ead028b11d --- tensorflow/python/eager/function_test.py | 22 ++++++++++++++++ .../python/framework/auto_control_deps.py | 11 ++++---- .../framework/auto_control_deps_test.py | 25 ++++++------------- .../python/grappler/constant_folding_test.py | 8 ++---- 4 files changed, 38 insertions(+), 28 deletions(-) diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 69433bc9ce3..bd91c6a564e 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -4291,6 +4291,28 @@ class FunctionTest(test.TestCase, parameterized.TestCase): with self.assertRaisesRegex(TypeError, 'missing required arguments: y'): foo.add(2) # pylint: disable=no-value-for-parameter + @test_util.run_v2_only + def testControlDependencyAfterInline(self): + v = variables.Variable(0.) + + @def_function.function + def assign(): + return v.assign(1.) + + @def_function.function + def assign_add(): + return v.assign_add(1.) + + @def_function.function + def f(): + check_ops.assert_equal_v2(assign(), 1.) + check_ops.assert_equal_v2(assign_add(), 2.) + + # We don't have a way to inspect the inlined graph in Python, so we run it + # multiple times to have more confidence the dependency is correct. + for _ in range(30): + f() + class MultiDeviceTest(test.TestCase, parameterized.TestCase): diff --git a/tensorflow/python/framework/auto_control_deps.py b/tensorflow/python/framework/auto_control_deps.py index 12f35a494b1..ba6a67910b1 100644 --- a/tensorflow/python/framework/auto_control_deps.py +++ b/tensorflow/python/framework/auto_control_deps.py @@ -373,11 +373,12 @@ class AutomaticControlDependencies(object): if control_flow_util.IsInWhileLoop(op): continue control_inputs = set() - # Ensure stateful ops run - if (op_def_registry.get(op.type) is None or - (op_is_stateful(op) and op.type not in utils.RESOURCE_READ_OPS)): - # TODO(srbs): Do not add functional ops to `ops_which_must_run` if - # they only have variable reads and are otherwise stateless. + # Ensure stateful ops run. Note that this includes read only ops, although + # they don't have direct side effect, they are affected by ops that writes + # the same resource and may be inputs to side-effect ops like tf.print. If + # the function gets inlined, they must execute before ops that depend on + # the function call. + if op_def_registry.get(op.type) is None or op_is_stateful(op): ops_which_must_run.add(op) # Make a note of all opened manager_ids. if op.type == "NoOp": diff --git a/tensorflow/python/framework/auto_control_deps_test.py b/tensorflow/python/framework/auto_control_deps_test.py index dc5d8986958..8b549263229 100644 --- a/tensorflow/python/framework/auto_control_deps_test.py +++ b/tensorflow/python/framework/auto_control_deps_test.py @@ -102,23 +102,14 @@ class AutomaticControlDependenciesTest(test.TestCase): self.assertNotIn(read_op1, read_op2.control_inputs) self.assertNotIn(read_op2, read_op1.control_inputs) - def testVariableReadsNotInOpsWithMustRun(self): + def testVariableReadsInOpsWithMustRun(self): with context.graph_mode(), self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) self.evaluate(variables.global_variables_initializer()) with acd.AutomaticControlDependencies() as c: - read_op1 = gen_resource_variable_ops.read_variable_op( + read_op = gen_resource_variable_ops.read_variable_op( v.handle, v.dtype).op - read_op2 = gen_resource_variable_ops.read_variable_op( - v.handle, v.dtype).op - assign_op = gen_resource_variable_ops.assign_variable_op( - v.handle, v + 1) - # Reads must not be in `ops_which_must_run` since those get added to the - # `control_outputs`. - self.assertNotIn(read_op1, c.ops_which_must_run) - self.assertNotIn(read_op2, c.ops_which_must_run) - # Last write must be in `ops_which_must_run`. - self.assertIn(assign_op, c.ops_which_must_run) + self.assertIn(read_op, c.ops_which_must_run) def testVariableMultipleReadsAndWrites(self): with context.graph_mode(), self.cached_session(): @@ -158,11 +149,11 @@ class AutomaticControlDependenciesTest(test.TestCase): for src_op, tgt_op in itertools.product(read_ops, read_ops): self.assertNotIn(src_op, tgt_op.control_inputs) - # Reads must not be in `ops_which_must_run`. - self.assertNotIn(read_op1, c.ops_which_must_run) - self.assertNotIn(read_op2, c.ops_which_must_run) - self.assertNotIn(read_op3, c.ops_which_must_run) - self.assertNotIn(read_op4, c.ops_which_must_run) + # Reads must be in `ops_which_must_run`. + self.assertIn(read_op1, c.ops_which_must_run) + self.assertIn(read_op2, c.ops_which_must_run) + self.assertIn(read_op3, c.ops_which_must_run) + self.assertIn(read_op4, c.ops_which_must_run) # Last write must be in `ops_which_must_run`. self.assertIn(assign_op4, c.ops_which_must_run) diff --git a/tensorflow/python/grappler/constant_folding_test.py b/tensorflow/python/grappler/constant_folding_test.py index 3336d3f7e8f..b5a9e7c8e48 100644 --- a/tensorflow/python/grappler/constant_folding_test.py +++ b/tensorflow/python/grappler/constant_folding_test.py @@ -96,19 +96,15 @@ class ConstantFoldingTest(test.TestCase): f(x, y).numpy() self.assertLen(graphs, 1) assign_count = 0 - read_count = 0 for node in graphs[0].node: if node.op == 'AssignAddVariableOp': self.assertEqual(node.input[0], 'y') assign_count += 1 - if node.op == 'ReadVariableOp': - read_count += 1 # Make sure that the only variable update that remains after - # grappler optimization is that of y, and that we prune all - # but the 2 necessary variable reads. + # grappler optimization is that of y. self.assertEqual(assign_count, 1) - self.assertEqual(read_count, 2) + self.assertLen(graphs[0].node, 11) if __name__ == '__main__':