Stop supporting cond v1 in automatic control dependencies.

Will allow us to simplify the code quite a bit.

PiperOrigin-RevId: 239280202
This commit is contained in:
Alexandre Passos 2019-03-19 15:16:39 -07:00 committed by TensorFlower Gardener
parent 8b53eb536e
commit a10ac56ab6
2 changed files with 1 additions and 228 deletions

View File

@ -23,7 +23,6 @@ from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest from tensorflow.python.util import nest
@ -168,65 +167,6 @@ class AutomaticControlDependencies(object):
self._n_operations = len(self._graph.get_operations()) self._n_operations = len(self._graph.get_operations())
return self return self
def _process_switch(self, switch_op, ops_which_must_run,
last_op_using_resource_tensor, merge_for_resource):
"""Processes a switch node for a resource input.
When tensorflow creates a cond, it creates a control flow context for each
branch of the cond. Each external tensor accessed by that branch is routed
through a switch op, which gets created in the graph _after_ the op which
uses that tensor get created.
If the resource comes from another switch op we process that one first.
_process_switch creates a corresponding merge node for the switch node. This
merge node is added to the outer control flow context of the switch
node. We also ensure that:
1. The switch node executes after the previous op which used the resource
tensor
2. Any op which uses a resource output of the switch node executes before
the merge for the switch node.
3. The next op which uses the input resource to the switch node (which
might be another switch node for the other branch of the conditional)
will execute after the merge node is done.
4. The merge node is marked as must_run so it will run even if no
subsequent operation uses the resource.
Args:
switch_op: the switch op to be processed
ops_which_must_run: the set of ops which must run
last_op_using_resource_tensor: map from resource tensor to last op using
it
merge_for_resource: map from resource tensor to merge which must follow
all usages of it.
"""
inp = switch_op.inputs[0]
if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
self._process_switch(inp.op, ops_which_must_run,
last_op_using_resource_tensor, merge_for_resource)
if switch_op.outputs[0] in merge_for_resource:
return
new_merge = control_flow_ops.merge(switch_op.outputs,
name="artificial_merge")
new_merge[0].op._control_flow_context = ( # pylint: disable=protected-access
switch_op._control_flow_context.outer_context) # pylint: disable=protected-access
# Ensures the merge always runs
ops_which_must_run.add(new_merge[0].op)
if inp in last_op_using_resource_tensor:
# Ensures the switch executes after the previous op using the resource.
switch_op._add_control_input(last_op_using_resource_tensor[inp]) # pylint: disable=protected-access
# Ensure the next op outside the cond happens after the merge.
last_op_using_resource_tensor[inp] = new_merge[0].op
if inp in merge_for_resource:
merge_for_resource[inp]._add_control_input(new_merge[0].op) # pylint: disable=protected-access
for o in switch_op.outputs:
# Ensures the merge will execute after all ops inside the cond
merge_for_resource[o] = new_merge[0].op
def __exit__(self, unused_type, unused_value, unused_traceback): def __exit__(self, unused_type, unused_value, unused_traceback):
if context.executing_eagerly(): if context.executing_eagerly():
return return
@ -247,8 +187,6 @@ class AutomaticControlDependencies(object):
last_op_using_resource_tensor = {} last_op_using_resource_tensor = {}
# set of conditional and loop exits # set of conditional and loop exits
ops_which_must_run = set() ops_which_must_run = set()
# merge which must depend on ops which use this resource
merge_for_resource = {}
new_operations = self._graph.get_operations()[self._n_operations:] new_operations = self._graph.get_operations()[self._n_operations:]
@ -290,16 +228,7 @@ class AutomaticControlDependencies(object):
or op_is_stateful(self._graph._registered_ops[op.type])): # pylint: disable=protected-access or op_is_stateful(self._graph._registered_ops[op.type])): # pylint: disable=protected-access
ops_which_must_run.add(op) ops_which_must_run.add(op)
# Ignore switches (they're handled separately) # Ignore switches (they're handled separately)
if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource: if op.type in ("Switch", "Merge", "Enter", "Exit", "NextIteration"):
continue
# Make merges trigger all other computation which must run
if op.type == "Merge":
for o in ops_which_must_run:
op._add_control_input(o) # pylint: disable=protected-access
for inp in o.inputs:
if inp in last_op_using_resource_tensor:
last_op_using_resource_tensor[inp] = op
ops_which_must_run = set([op])
continue continue
found_resource = False found_resource = False
# Check for any resource inputs. If we find any, we update control_inputs # Check for any resource inputs. If we find any, we update control_inputs
@ -310,19 +239,11 @@ class AutomaticControlDependencies(object):
if inp.dtype != dtypes_module.resource: if inp.dtype != dtypes_module.resource:
continue continue
found_resource = True found_resource = True
# Deal with switches, finally.
if inp.op.type == "Switch":
self._process_switch(inp.op, ops_which_must_run,
last_op_using_resource_tensor,
merge_for_resource)
# Ensure uses of resources are serialized # Ensure uses of resources are serialized
if inp in last_op_using_resource_tensor: if inp in last_op_using_resource_tensor:
if (last_op_using_resource_tensor[inp]._control_flow_context # pylint: disable=protected-access if (last_op_using_resource_tensor[inp]._control_flow_context # pylint: disable=protected-access
is op._control_flow_context): # pylint: disable=protected-access is op._control_flow_context): # pylint: disable=protected-access
control_inputs.add(last_op_using_resource_tensor[inp]) control_inputs.add(last_op_using_resource_tensor[inp])
# Ensure merges happen after the closing of a cond block
if inp in merge_for_resource:
merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access
last_op_using_resource_tensor[inp] = op last_op_using_resource_tensor[inp] = op
if (op_is_stateful(op.op_def) and not found_resource if (op_is_stateful(op.op_def) and not found_resource
and op._control_flow_context is None): # pylint: disable=protected-access and op._control_flow_context is None): # pylint: disable=protected-access

View File

@ -26,9 +26,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.keras.layers import core as keras_core from tensorflow.python.keras.layers import core as keras_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
@ -50,152 +48,6 @@ class AutomaticControlDependenciesTest(test.TestCase):
val = c.mark_as_return(val) val = c.mark_as_return(val)
self.assertAllEqual(val.eval(), 4.0) self.assertAllEqual(val.eval(), 4.0)
@test_util.run_v1_only("b/120545219")
def testCondMustRun(self):
with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
self.evaluate(variables.global_variables_initializer())
p = array_ops.placeholder(dtype=dtypes.bool)
with acd.AutomaticControlDependencies() as c:
def true_fn():
v.assign(v + 1)
return 0.0
def false_fn():
v.assign(v + 4)
return 1.0
control_flow_ops.cond(p, true_fn, false_fn)
val = v.read_value()
val = c.mark_as_return(val)
self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0)
self.assertAllEqual(val.eval(feed_dict={p: True}), 6.0)
@test_util.run_v1_only("b/120545219")
def testCondMustRunSeparateRead(self):
with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
self.evaluate(variables.global_variables_initializer())
p = array_ops.placeholder(dtype=dtypes.bool)
with acd.AutomaticControlDependencies() as c:
def true_fn():
v.assign(v + 1)
return 0.0
def false_fn():
v.assign(v + 4)
return 1.0
control_flow_ops.cond(p, true_fn, false_fn)
one = constant_op.constant(1.0)
one = c.mark_as_return(one)
one.eval(feed_dict={p: False})
self.assertAllEqual(v.read_value().eval(), 5.0)
one.eval(feed_dict={p: True})
self.assertAllEqual(v.read_value().eval(), 6.0)
@test_util.run_v1_only("b/120545219")
def testCondNested(self):
with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
self.evaluate(variables.global_variables_initializer())
p = array_ops.placeholder(dtype=dtypes.bool)
q = array_ops.placeholder(dtype=dtypes.bool)
with acd.AutomaticControlDependencies() as c:
def true_fn():
v.assign(v + 1, name='true')
return 1.0
def false_fn():
def inner_true_fn():
v.assign(v * 2, name='false_true')
return 2.0
def inner_false_fn():
v.assign(v * 3, name='false_false')
return 3.0
control_flow_ops.cond(q, inner_true_fn, inner_false_fn)
return 1.0
control_flow_ops.cond(p, true_fn, false_fn)
with ops.name_scope('final'):
val = v.read_value()
val = c.mark_as_return(val)
self.assertAllEqual(val.eval(feed_dict={p: False, q: False}), 3.0)
self.assertAllEqual(val.eval(feed_dict={p: False, q: True}), 6.0)
self.assertAllEqual(val.eval(feed_dict={p: True, q: True}), 7.0)
self.assertAllEqual(val.eval(feed_dict={p: True, q: False}), 8.0)
@test_util.run_v1_only("b/120545219")
def testCondOneBranch(self):
with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
self.evaluate(variables.global_variables_initializer())
p = array_ops.placeholder(dtype=dtypes.bool)
with acd.AutomaticControlDependencies() as c:
def true_fn():
return 0.0
def false_fn():
v.assign(v + 4)
return 1.0
control_flow_ops.cond(p, true_fn, false_fn)
val = v.read_value()
val = c.mark_as_return(val)
self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0)
self.assertAllEqual(val.eval(feed_dict={p: True}), 5.0)
@test_util.run_v1_only("b/120545219")
def testCondOneBranchUpdateBefore(self):
with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
self.evaluate(variables.global_variables_initializer())
p = array_ops.placeholder(dtype=dtypes.bool)
with acd.AutomaticControlDependencies() as c:
v.assign(v * 2)
def true_fn():
return 0.0
def false_fn():
v.assign(v + 4)
return 1.0
control_flow_ops.cond(p, true_fn, false_fn)
val = v.read_value()
val = c.mark_as_return(val)
self.assertAllEqual(val.eval(feed_dict={p: False}), 6.0)
self.assertAllEqual(val.eval(feed_dict={p: True}), 12.0)
@test_util.run_v1_only("b/120545219")
def testCondOneBranchUpdateAfter(self):
with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
self.evaluate(variables.global_variables_initializer())
p = array_ops.placeholder(dtype=dtypes.bool)
with acd.AutomaticControlDependencies() as c:
def true_fn():
return 0.0
def false_fn():
v.assign(v + 4)
return 1.0
control_flow_ops.cond(p, true_fn, false_fn)
v.assign(v * 2)
val = v.read_value()
val = c.mark_as_return(val)
self.assertAllEqual(val.eval(feed_dict={p: False}), 10.0)
self.assertAllEqual(val.eval(feed_dict={p: True}), 20.0)
def testDefunWhileLoopWithCapturedLoopVars(self): def testDefunWhileLoopWithCapturedLoopVars(self):
n = 3 n = 3
x = constant_op.constant(list(range(n))) x = constant_op.constant(list(range(n)))