From 06c6f5ffc585d82ee6f313fd6f29d204bc75b25e Mon Sep 17 00:00:00 2001
From: Shohini Ghosh
Date: Wed, 20 Mar 2019 17:28:06 -0700
Subject: [PATCH] Automated rollback of commit a10ac56ab6999ed8ce55ef28879aa3cb3cefc43b

PiperOrigin-RevId: 239509940
---
 .../python/framework/auto_control_deps.py    |  81 +++++++++-
 .../framework/auto_control_deps_test.py      | 148 ++++++++++++++++++
 2 files changed, 228 insertions(+), 1 deletion(-)

diff --git a/tensorflow/python/framework/auto_control_deps.py b/tensorflow/python/framework/auto_control_deps.py
index bcf9e5bd3da..a8ba4ea50d1 100644
--- a/tensorflow/python/framework/auto_control_deps.py
+++ b/tensorflow/python/framework/auto_control_deps.py
@@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes as dtypes_module
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.util import nest
@@ -167,6 +168,65 @@ class AutomaticControlDependencies(object):
     self._n_operations = len(self._graph.get_operations())
     return self

+  def _process_switch(self, switch_op, ops_which_must_run,
+                      last_op_using_resource_tensor, merge_for_resource):
+    """Processes a switch node for a resource input.
+
+    When TensorFlow creates a cond, it creates a control flow context for each
+    branch of the cond. Each external tensor accessed by that branch is routed
+    through a switch op, which is created in the graph _after_ the op which
+    uses that tensor.
+
+    If the resource comes from another switch op, we process that one first.
+
+    _process_switch creates a corresponding merge node for the switch node.
+    This merge node is added to the outer control flow context of the switch
+    node. We also ensure that:
+
+      1. The switch node executes after the previous op which used the
+         resource tensor.
+
+      2. Any op which uses a resource output of the switch node executes
+         before the merge for the switch node.
+
+      3. The next op which uses the input resource to the switch node (which
+         might be another switch node for the other branch of the conditional)
+         will execute after the merge node is done.
+
+      4. The merge node is marked as must_run, so it will run even if no
+         subsequent operation uses the resource.
+
+    Args:
+      switch_op: the switch op to be processed.
+      ops_which_must_run: the set of ops which must run.
+      last_op_using_resource_tensor: map from resource tensor to the last op
+        using it.
+      merge_for_resource: map from resource tensor to the merge which must
+        follow all usages of it.
+    """
+    inp = switch_op.inputs[0]
+    if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
+      self._process_switch(inp.op, ops_which_must_run,
+                           last_op_using_resource_tensor, merge_for_resource)
+    if switch_op.outputs[0] in merge_for_resource:
+      return
+    new_merge = control_flow_ops.merge(switch_op.outputs,
+                                       name="artificial_merge")
+    new_merge[0].op._control_flow_context = (  # pylint: disable=protected-access
+        switch_op._control_flow_context.outer_context)  # pylint: disable=protected-access
+    # Ensures the merge always runs
+    ops_which_must_run.add(new_merge[0].op)
+    if inp in last_op_using_resource_tensor:
+      # Ensures the switch executes after the previous op using the resource.
+      switch_op._add_control_input(last_op_using_resource_tensor[inp])  # pylint: disable=protected-access
+    # Ensure the next op outside the cond happens after the merge.
+    last_op_using_resource_tensor[inp] = new_merge[0].op
+    if inp in merge_for_resource:
+      merge_for_resource[inp]._add_control_input(new_merge[0].op)  # pylint: disable=protected-access
+    for o in switch_op.outputs:
+      # Ensures the merge will execute after all ops inside the cond
+      merge_for_resource[o] = new_merge[0].op
+
   def __exit__(self, unused_type, unused_value, unused_traceback):
     if context.executing_eagerly():
       return
@@ -187,6 +247,8 @@ class AutomaticControlDependencies(object):
     last_op_using_resource_tensor = {}
     # set of conditional and loop exits
     ops_which_must_run = set()
+    # merge which must depend on ops which use this resource
+    merge_for_resource = {}

     new_operations = self._graph.get_operations()[self._n_operations:]

@@ -228,7 +290,16 @@ class AutomaticControlDependencies(object):
           or op_is_stateful(self._graph._registered_ops[op.type])):  # pylint: disable=protected-access
         ops_which_must_run.add(op)
       # Ignore switches (they're handled separately)
-      if op.type in ("Switch", "Merge", "Enter", "Exit", "NextIteration"):
+      if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
+        continue
+      # Make merges trigger all other computation which must run
+      if op.type == "Merge":
+        for o in ops_which_must_run:
+          op._add_control_input(o)  # pylint: disable=protected-access
+          for inp in o.inputs:
+            if inp in last_op_using_resource_tensor:
+              last_op_using_resource_tensor[inp] = op
+        ops_which_must_run = set([op])
         continue
       found_resource = False
       # Check for any resource inputs. If we find any, we update control_inputs
@@ -239,11 +310,19 @@ class AutomaticControlDependencies(object):
         if inp.dtype != dtypes_module.resource:
           continue
         found_resource = True
+        # Deal with switches, finally.
+        if inp.op.type == "Switch":
+          self._process_switch(inp.op, ops_which_must_run,
+                               last_op_using_resource_tensor,
+                               merge_for_resource)
         # Ensure uses of resources are serialized
         if inp in last_op_using_resource_tensor:
           if (last_op_using_resource_tensor[inp]._control_flow_context  # pylint: disable=protected-access
               is op._control_flow_context):  # pylint: disable=protected-access
             control_inputs.add(last_op_using_resource_tensor[inp])
+        # Ensure merges happen after the closing of a cond block
+        if inp in merge_for_resource:
+          merge_for_resource[inp]._add_control_input(op)  # pylint: disable=protected-access
         last_op_using_resource_tensor[inp] = op
       if (op_is_stateful(op.op_def) and not found_resource
           and op._control_flow_context is None):  # pylint: disable=protected-access
diff --git a/tensorflow/python/framework/auto_control_deps_test.py b/tensorflow/python/framework/auto_control_deps_test.py
index c8c5e18c7ce..19c0606eb42 100644
--- a/tensorflow/python/framework/auto_control_deps_test.py
+++ b/tensorflow/python/framework/auto_control_deps_test.py
@@ -26,7 +26,9 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import test_util
 from tensorflow.python.keras.layers import core as keras_core
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
@@ -48,6 +50,152 @@ class AutomaticControlDependenciesTest(test.TestCase):
         val = c.mark_as_return(val)
       self.assertAllEqual(val.eval(), 4.0)

+  @test_util.run_v1_only("b/120545219")
+  def testCondMustRun(self):
+    with context.graph_mode(), self.cached_session():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      self.evaluate(variables.global_variables_initializer())
+      p = array_ops.placeholder(dtype=dtypes.bool)
+      with acd.AutomaticControlDependencies() as c:
+
+        def true_fn():
+          v.assign(v + 1)
+          return 0.0
+
+        def false_fn():
+          v.assign(v + 4)
+          return 1.0
+
+        control_flow_ops.cond(p, true_fn, false_fn)
+        val = v.read_value()
+        val = c.mark_as_return(val)
+      self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0)
+      self.assertAllEqual(val.eval(feed_dict={p: True}), 6.0)
+
+  @test_util.run_v1_only("b/120545219")
+  def testCondMustRunSeparateRead(self):
+    with context.graph_mode(), self.cached_session():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      self.evaluate(variables.global_variables_initializer())
+      p = array_ops.placeholder(dtype=dtypes.bool)
+      with acd.AutomaticControlDependencies() as c:
+
+        def true_fn():
+          v.assign(v + 1)
+          return 0.0
+
+        def false_fn():
+          v.assign(v + 4)
+          return 1.0
+
+        control_flow_ops.cond(p, true_fn, false_fn)
+        one = constant_op.constant(1.0)
+        one = c.mark_as_return(one)
+      one.eval(feed_dict={p: False})
+      self.assertAllEqual(v.read_value().eval(), 5.0)
+      one.eval(feed_dict={p: True})
+      self.assertAllEqual(v.read_value().eval(), 6.0)
+
+  @test_util.run_v1_only("b/120545219")
+  def testCondNested(self):
+    with context.graph_mode(), self.cached_session():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      self.evaluate(variables.global_variables_initializer())
+      p = array_ops.placeholder(dtype=dtypes.bool)
+      q = array_ops.placeholder(dtype=dtypes.bool)
+      with acd.AutomaticControlDependencies() as c:
+
+        def true_fn():
+          v.assign(v + 1, name="true")
+          return 1.0
+
+        def false_fn():
+
+          def inner_true_fn():
+            v.assign(v * 2, name="false_true")
+            return 2.0
+
+          def inner_false_fn():
+            v.assign(v * 3, name="false_false")
+            return 3.0
+
+          control_flow_ops.cond(q, inner_true_fn, inner_false_fn)
+          return 1.0
+
+        control_flow_ops.cond(p, true_fn, false_fn)
+        with ops.name_scope("final"):
+          val = v.read_value()
+        val = c.mark_as_return(val)
+      self.assertAllEqual(val.eval(feed_dict={p: False, q: False}), 3.0)
+      self.assertAllEqual(val.eval(feed_dict={p: False, q: True}), 6.0)
+      self.assertAllEqual(val.eval(feed_dict={p: True, q: True}), 7.0)
+      self.assertAllEqual(val.eval(feed_dict={p: True, q: False}), 8.0)
+
+  @test_util.run_v1_only("b/120545219")
+  def testCondOneBranch(self):
+    with context.graph_mode(), self.cached_session():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      self.evaluate(variables.global_variables_initializer())
+      p = array_ops.placeholder(dtype=dtypes.bool)
+      with acd.AutomaticControlDependencies() as c:
+
+        def true_fn():
+          return 0.0
+
+        def false_fn():
+          v.assign(v + 4)
+          return 1.0
+
+        control_flow_ops.cond(p, true_fn, false_fn)
+        val = v.read_value()
+        val = c.mark_as_return(val)
+      self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0)
+      self.assertAllEqual(val.eval(feed_dict={p: True}), 5.0)
+
+  @test_util.run_v1_only("b/120545219")
+  def testCondOneBranchUpdateBefore(self):
+    with context.graph_mode(), self.cached_session():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      self.evaluate(variables.global_variables_initializer())
+      p = array_ops.placeholder(dtype=dtypes.bool)
+      with acd.AutomaticControlDependencies() as c:
+        v.assign(v * 2)
+
+        def true_fn():
+          return 0.0
+
+        def false_fn():
+          v.assign(v + 4)
+          return 1.0
+
+        control_flow_ops.cond(p, true_fn, false_fn)
+        val = v.read_value()
+        val = c.mark_as_return(val)
+      self.assertAllEqual(val.eval(feed_dict={p: False}), 6.0)
+      self.assertAllEqual(val.eval(feed_dict={p: True}), 12.0)
+
+  @test_util.run_v1_only("b/120545219")
+  def testCondOneBranchUpdateAfter(self):
+    with context.graph_mode(), self.cached_session():
+      v = resource_variable_ops.ResourceVariable(1.0)
+      self.evaluate(variables.global_variables_initializer())
+      p = array_ops.placeholder(dtype=dtypes.bool)
+      with acd.AutomaticControlDependencies() as c:
+
+        def true_fn():
+          return 0.0
+
+        def false_fn():
+          v.assign(v + 4)
+          return 1.0
+
+        control_flow_ops.cond(p, true_fn, false_fn)
+        v.assign(v * 2)
+        val = v.read_value()
+        val = c.mark_as_return(val)
+      self.assertAllEqual(val.eval(feed_dict={p: False}), 10.0)
+      self.assertAllEqual(val.eval(feed_dict={p: True}), 20.0)
+
   def testDefunWhileLoopWithCapturedLoopVars(self):
     n = 3
     x = constant_op.constant(list(range(n)))
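
Note (not part of the patch): the following is a minimal standalone sketch of the behavior the new switch/merge handling guarantees, mirroring testCondMustRun above. It assumes a TF 1.x build where these internal modules exist and relies on graph mode being the default; it is illustrative only, not shipped code.

# Sketch: with the artificial merge marked as must-run, the assign inside the
# taken cond branch executes even though the cond's output is never used, so
# the read below observes the update.
from tensorflow.python.client import session
from tensorflow.python.framework import auto_control_deps as acd
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables

v = resource_variable_ops.ResourceVariable(1.0)
p = array_ops.placeholder(dtype=dtypes.bool)
with acd.AutomaticControlDependencies() as c:

  def true_fn():
    v.assign(v + 1)
    return 0.0

  def false_fn():
    v.assign(v + 4)
    return 1.0

  # The cond's return value is discarded; only its side effect matters.
  control_flow_ops.cond(p, true_fn, false_fn)
  val = c.mark_as_return(v.read_value())

with session.Session() as sess:
  sess.run(variables.global_variables_initializer())
  print(sess.run(val, feed_dict={p: False}))  # expected 5.0: false-branch assign ran
  print(sess.run(val, feed_dict={p: True}))   # expected 6.0: true-branch assign ran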