diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 69433bc9ce3..bd91c6a564e 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -4291,6 +4291,28 @@ class FunctionTest(test.TestCase, parameterized.TestCase): with self.assertRaisesRegex(TypeError, 'missing required arguments: y'): foo.add(2) # pylint: disable=no-value-for-parameter + @test_util.run_v2_only + def testControlDependencyAfterInline(self): + v = variables.Variable(0.) + + @def_function.function + def assign(): + return v.assign(1.) + + @def_function.function + def assign_add(): + return v.assign_add(1.) + + @def_function.function + def f(): + check_ops.assert_equal_v2(assign(), 1.) + check_ops.assert_equal_v2(assign_add(), 2.) + + # We don't have a way to inspect the inlined graph in Python, so we run it + # multiple times to have more confidence the dependency is correct. + for _ in range(30): + f() + class MultiDeviceTest(test.TestCase, parameterized.TestCase): diff --git a/tensorflow/python/framework/auto_control_deps.py b/tensorflow/python/framework/auto_control_deps.py index 12f35a494b1..ba6a67910b1 100644 --- a/tensorflow/python/framework/auto_control_deps.py +++ b/tensorflow/python/framework/auto_control_deps.py @@ -373,11 +373,12 @@ class AutomaticControlDependencies(object): if control_flow_util.IsInWhileLoop(op): continue control_inputs = set() - # Ensure stateful ops run - if (op_def_registry.get(op.type) is None or - (op_is_stateful(op) and op.type not in utils.RESOURCE_READ_OPS)): - # TODO(srbs): Do not add functional ops to `ops_which_must_run` if - # they only have variable reads and are otherwise stateless. + # Ensure stateful ops run. Note that this includes read only ops, although + # they don't have direct side effect, they are affected by ops that writes + # the same resource and may be inputs to side-effect ops like tf.print. If + # the function gets inlined, they must execute before ops that depend on + # the function call. + if op_def_registry.get(op.type) is None or op_is_stateful(op): ops_which_must_run.add(op) # Make a note of all opened manager_ids. if op.type == "NoOp": diff --git a/tensorflow/python/framework/auto_control_deps_test.py b/tensorflow/python/framework/auto_control_deps_test.py index dc5d8986958..8b549263229 100644 --- a/tensorflow/python/framework/auto_control_deps_test.py +++ b/tensorflow/python/framework/auto_control_deps_test.py @@ -102,23 +102,14 @@ class AutomaticControlDependenciesTest(test.TestCase): self.assertNotIn(read_op1, read_op2.control_inputs) self.assertNotIn(read_op2, read_op1.control_inputs) - def testVariableReadsNotInOpsWithMustRun(self): + def testVariableReadsInOpsWithMustRun(self): with context.graph_mode(), self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) self.evaluate(variables.global_variables_initializer()) with acd.AutomaticControlDependencies() as c: - read_op1 = gen_resource_variable_ops.read_variable_op( + read_op = gen_resource_variable_ops.read_variable_op( v.handle, v.dtype).op - read_op2 = gen_resource_variable_ops.read_variable_op( - v.handle, v.dtype).op - assign_op = gen_resource_variable_ops.assign_variable_op( - v.handle, v + 1) - # Reads must not be in `ops_which_must_run` since those get added to the - # `control_outputs`. - self.assertNotIn(read_op1, c.ops_which_must_run) - self.assertNotIn(read_op2, c.ops_which_must_run) - # Last write must be in `ops_which_must_run`. - self.assertIn(assign_op, c.ops_which_must_run) + self.assertIn(read_op, c.ops_which_must_run) def testVariableMultipleReadsAndWrites(self): with context.graph_mode(), self.cached_session(): @@ -158,11 +149,11 @@ class AutomaticControlDependenciesTest(test.TestCase): for src_op, tgt_op in itertools.product(read_ops, read_ops): self.assertNotIn(src_op, tgt_op.control_inputs) - # Reads must not be in `ops_which_must_run`. - self.assertNotIn(read_op1, c.ops_which_must_run) - self.assertNotIn(read_op2, c.ops_which_must_run) - self.assertNotIn(read_op3, c.ops_which_must_run) - self.assertNotIn(read_op4, c.ops_which_must_run) + # Reads must be in `ops_which_must_run`. + self.assertIn(read_op1, c.ops_which_must_run) + self.assertIn(read_op2, c.ops_which_must_run) + self.assertIn(read_op3, c.ops_which_must_run) + self.assertIn(read_op4, c.ops_which_must_run) # Last write must be in `ops_which_must_run`. self.assertIn(assign_op4, c.ops_which_must_run) diff --git a/tensorflow/python/grappler/constant_folding_test.py b/tensorflow/python/grappler/constant_folding_test.py index 3336d3f7e8f..b5a9e7c8e48 100644 --- a/tensorflow/python/grappler/constant_folding_test.py +++ b/tensorflow/python/grappler/constant_folding_test.py @@ -96,19 +96,15 @@ class ConstantFoldingTest(test.TestCase): f(x, y).numpy() self.assertLen(graphs, 1) assign_count = 0 - read_count = 0 for node in graphs[0].node: if node.op == 'AssignAddVariableOp': self.assertEqual(node.input[0], 'y') assign_count += 1 - if node.op == 'ReadVariableOp': - read_count += 1 # Make sure that the only variable update that remains after - # grappler optimization is that of y, and that we prune all - # but the 2 necessary variable reads. + # grappler optimization is that of y. self.assertEqual(assign_count, 1) - self.assertEqual(read_count, 2) + self.assertLen(graphs[0].node, 11) if __name__ == '__main__':