Stop supporting cond v1 in automatic control dependencies.
Will allow us to simplify the code quite a bit. PiperOrigin-RevId: 239280202
This commit is contained in:
parent
8b53eb536e
commit
a10ac56ab6
@ -23,7 +23,6 @@ from tensorflow.python.framework import dtypes as dtypes_module
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
|
||||||
from tensorflow.python.ops import control_flow_util
|
from tensorflow.python.ops import control_flow_util
|
||||||
from tensorflow.python.ops import tensor_array_ops
|
from tensorflow.python.ops import tensor_array_ops
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
@ -168,65 +167,6 @@ class AutomaticControlDependencies(object):
|
|||||||
self._n_operations = len(self._graph.get_operations())
|
self._n_operations = len(self._graph.get_operations())
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _process_switch(self, switch_op, ops_which_must_run,
|
|
||||||
last_op_using_resource_tensor, merge_for_resource):
|
|
||||||
"""Processes a switch node for a resource input.
|
|
||||||
|
|
||||||
When tensorflow creates a cond, it creates a control flow context for each
|
|
||||||
branch of the cond. Each external tensor accessed by that branch is routed
|
|
||||||
through a switch op, which gets created in the graph _after_ the op which
|
|
||||||
uses that tensor get created.
|
|
||||||
|
|
||||||
If the resource comes from another switch op we process that one first.
|
|
||||||
|
|
||||||
_process_switch creates a corresponding merge node for the switch node. This
|
|
||||||
merge node is added to the outer control flow context of the switch
|
|
||||||
node. We also ensure that:
|
|
||||||
|
|
||||||
1. The switch node executes after the previous op which used the resource
|
|
||||||
tensor
|
|
||||||
|
|
||||||
2. Any op which uses a resource output of the switch node executes before
|
|
||||||
the merge for the switch node.
|
|
||||||
|
|
||||||
3. The next op which uses the input resource to the switch node (which
|
|
||||||
might be another switch node for the other branch of the conditional)
|
|
||||||
will execute after the merge node is done.
|
|
||||||
|
|
||||||
4. The merge node is marked as must_run so it will run even if no
|
|
||||||
subsequent operation uses the resource.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
switch_op: the switch op to be processed
|
|
||||||
ops_which_must_run: the set of ops which must run
|
|
||||||
last_op_using_resource_tensor: map from resource tensor to last op using
|
|
||||||
it
|
|
||||||
merge_for_resource: map from resource tensor to merge which must follow
|
|
||||||
all usages of it.
|
|
||||||
"""
|
|
||||||
inp = switch_op.inputs[0]
|
|
||||||
if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
|
|
||||||
self._process_switch(inp.op, ops_which_must_run,
|
|
||||||
last_op_using_resource_tensor, merge_for_resource)
|
|
||||||
if switch_op.outputs[0] in merge_for_resource:
|
|
||||||
return
|
|
||||||
new_merge = control_flow_ops.merge(switch_op.outputs,
|
|
||||||
name="artificial_merge")
|
|
||||||
new_merge[0].op._control_flow_context = ( # pylint: disable=protected-access
|
|
||||||
switch_op._control_flow_context.outer_context) # pylint: disable=protected-access
|
|
||||||
# Ensures the merge always runs
|
|
||||||
ops_which_must_run.add(new_merge[0].op)
|
|
||||||
if inp in last_op_using_resource_tensor:
|
|
||||||
# Ensures the switch executes after the previous op using the resource.
|
|
||||||
switch_op._add_control_input(last_op_using_resource_tensor[inp]) # pylint: disable=protected-access
|
|
||||||
# Ensure the next op outside the cond happens after the merge.
|
|
||||||
last_op_using_resource_tensor[inp] = new_merge[0].op
|
|
||||||
if inp in merge_for_resource:
|
|
||||||
merge_for_resource[inp]._add_control_input(new_merge[0].op) # pylint: disable=protected-access
|
|
||||||
for o in switch_op.outputs:
|
|
||||||
# Ensures the merge will execute after all ops inside the cond
|
|
||||||
merge_for_resource[o] = new_merge[0].op
|
|
||||||
|
|
||||||
def __exit__(self, unused_type, unused_value, unused_traceback):
|
def __exit__(self, unused_type, unused_value, unused_traceback):
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
return
|
return
|
||||||
@ -247,8 +187,6 @@ class AutomaticControlDependencies(object):
|
|||||||
last_op_using_resource_tensor = {}
|
last_op_using_resource_tensor = {}
|
||||||
# set of conditional and loop exits
|
# set of conditional and loop exits
|
||||||
ops_which_must_run = set()
|
ops_which_must_run = set()
|
||||||
# merge which must depend on ops which use this resource
|
|
||||||
merge_for_resource = {}
|
|
||||||
|
|
||||||
new_operations = self._graph.get_operations()[self._n_operations:]
|
new_operations = self._graph.get_operations()[self._n_operations:]
|
||||||
|
|
||||||
@ -290,16 +228,7 @@ class AutomaticControlDependencies(object):
|
|||||||
or op_is_stateful(self._graph._registered_ops[op.type])): # pylint: disable=protected-access
|
or op_is_stateful(self._graph._registered_ops[op.type])): # pylint: disable=protected-access
|
||||||
ops_which_must_run.add(op)
|
ops_which_must_run.add(op)
|
||||||
# Ignore switches (they're handled separately)
|
# Ignore switches (they're handled separately)
|
||||||
if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
|
if op.type in ("Switch", "Merge", "Enter", "Exit", "NextIteration"):
|
||||||
continue
|
|
||||||
# Make merges trigger all other computation which must run
|
|
||||||
if op.type == "Merge":
|
|
||||||
for o in ops_which_must_run:
|
|
||||||
op._add_control_input(o) # pylint: disable=protected-access
|
|
||||||
for inp in o.inputs:
|
|
||||||
if inp in last_op_using_resource_tensor:
|
|
||||||
last_op_using_resource_tensor[inp] = op
|
|
||||||
ops_which_must_run = set([op])
|
|
||||||
continue
|
continue
|
||||||
found_resource = False
|
found_resource = False
|
||||||
# Check for any resource inputs. If we find any, we update control_inputs
|
# Check for any resource inputs. If we find any, we update control_inputs
|
||||||
@ -310,19 +239,11 @@ class AutomaticControlDependencies(object):
|
|||||||
if inp.dtype != dtypes_module.resource:
|
if inp.dtype != dtypes_module.resource:
|
||||||
continue
|
continue
|
||||||
found_resource = True
|
found_resource = True
|
||||||
# Deal with switches, finally.
|
|
||||||
if inp.op.type == "Switch":
|
|
||||||
self._process_switch(inp.op, ops_which_must_run,
|
|
||||||
last_op_using_resource_tensor,
|
|
||||||
merge_for_resource)
|
|
||||||
# Ensure uses of resources are serialized
|
# Ensure uses of resources are serialized
|
||||||
if inp in last_op_using_resource_tensor:
|
if inp in last_op_using_resource_tensor:
|
||||||
if (last_op_using_resource_tensor[inp]._control_flow_context # pylint: disable=protected-access
|
if (last_op_using_resource_tensor[inp]._control_flow_context # pylint: disable=protected-access
|
||||||
is op._control_flow_context): # pylint: disable=protected-access
|
is op._control_flow_context): # pylint: disable=protected-access
|
||||||
control_inputs.add(last_op_using_resource_tensor[inp])
|
control_inputs.add(last_op_using_resource_tensor[inp])
|
||||||
# Ensure merges happen after the closing of a cond block
|
|
||||||
if inp in merge_for_resource:
|
|
||||||
merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access
|
|
||||||
last_op_using_resource_tensor[inp] = op
|
last_op_using_resource_tensor[inp] = op
|
||||||
if (op_is_stateful(op.op_def) and not found_resource
|
if (op_is_stateful(op.op_def) and not found_resource
|
||||||
and op._control_flow_context is None): # pylint: disable=protected-access
|
and op._control_flow_context is None): # pylint: disable=protected-access
|
||||||
|
@ -26,9 +26,7 @@ from tensorflow.python.framework import constant_op
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.framework import test_util
|
|
||||||
from tensorflow.python.keras.layers import core as keras_core
|
from tensorflow.python.keras.layers import core as keras_core
|
||||||
from tensorflow.python.ops import array_ops
|
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
@ -50,152 +48,6 @@ class AutomaticControlDependenciesTest(test.TestCase):
|
|||||||
val = c.mark_as_return(val)
|
val = c.mark_as_return(val)
|
||||||
self.assertAllEqual(val.eval(), 4.0)
|
self.assertAllEqual(val.eval(), 4.0)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testCondMustRun(self):
|
|
||||||
with context.graph_mode(), self.cached_session():
|
|
||||||
v = resource_variable_ops.ResourceVariable(1.0)
|
|
||||||
self.evaluate(variables.global_variables_initializer())
|
|
||||||
p = array_ops.placeholder(dtype=dtypes.bool)
|
|
||||||
with acd.AutomaticControlDependencies() as c:
|
|
||||||
|
|
||||||
def true_fn():
|
|
||||||
v.assign(v + 1)
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
def false_fn():
|
|
||||||
v.assign(v + 4)
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
control_flow_ops.cond(p, true_fn, false_fn)
|
|
||||||
val = v.read_value()
|
|
||||||
val = c.mark_as_return(val)
|
|
||||||
self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0)
|
|
||||||
self.assertAllEqual(val.eval(feed_dict={p: True}), 6.0)
|
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testCondMustRunSeparateRead(self):
|
|
||||||
with context.graph_mode(), self.cached_session():
|
|
||||||
v = resource_variable_ops.ResourceVariable(1.0)
|
|
||||||
self.evaluate(variables.global_variables_initializer())
|
|
||||||
p = array_ops.placeholder(dtype=dtypes.bool)
|
|
||||||
with acd.AutomaticControlDependencies() as c:
|
|
||||||
|
|
||||||
def true_fn():
|
|
||||||
v.assign(v + 1)
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
def false_fn():
|
|
||||||
v.assign(v + 4)
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
control_flow_ops.cond(p, true_fn, false_fn)
|
|
||||||
one = constant_op.constant(1.0)
|
|
||||||
one = c.mark_as_return(one)
|
|
||||||
one.eval(feed_dict={p: False})
|
|
||||||
self.assertAllEqual(v.read_value().eval(), 5.0)
|
|
||||||
one.eval(feed_dict={p: True})
|
|
||||||
self.assertAllEqual(v.read_value().eval(), 6.0)
|
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testCondNested(self):
|
|
||||||
with context.graph_mode(), self.cached_session():
|
|
||||||
v = resource_variable_ops.ResourceVariable(1.0)
|
|
||||||
self.evaluate(variables.global_variables_initializer())
|
|
||||||
p = array_ops.placeholder(dtype=dtypes.bool)
|
|
||||||
q = array_ops.placeholder(dtype=dtypes.bool)
|
|
||||||
with acd.AutomaticControlDependencies() as c:
|
|
||||||
|
|
||||||
def true_fn():
|
|
||||||
v.assign(v + 1, name='true')
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
def false_fn():
|
|
||||||
|
|
||||||
def inner_true_fn():
|
|
||||||
v.assign(v * 2, name='false_true')
|
|
||||||
return 2.0
|
|
||||||
|
|
||||||
def inner_false_fn():
|
|
||||||
v.assign(v * 3, name='false_false')
|
|
||||||
return 3.0
|
|
||||||
|
|
||||||
control_flow_ops.cond(q, inner_true_fn, inner_false_fn)
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
control_flow_ops.cond(p, true_fn, false_fn)
|
|
||||||
with ops.name_scope('final'):
|
|
||||||
val = v.read_value()
|
|
||||||
val = c.mark_as_return(val)
|
|
||||||
self.assertAllEqual(val.eval(feed_dict={p: False, q: False}), 3.0)
|
|
||||||
self.assertAllEqual(val.eval(feed_dict={p: False, q: True}), 6.0)
|
|
||||||
self.assertAllEqual(val.eval(feed_dict={p: True, q: True}), 7.0)
|
|
||||||
self.assertAllEqual(val.eval(feed_dict={p: True, q: False}), 8.0)
|
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testCondOneBranch(self):
|
|
||||||
with context.graph_mode(), self.cached_session():
|
|
||||||
v = resource_variable_ops.ResourceVariable(1.0)
|
|
||||||
self.evaluate(variables.global_variables_initializer())
|
|
||||||
p = array_ops.placeholder(dtype=dtypes.bool)
|
|
||||||
with acd.AutomaticControlDependencies() as c:
|
|
||||||
|
|
||||||
def true_fn():
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
def false_fn():
|
|
||||||
v.assign(v + 4)
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
control_flow_ops.cond(p, true_fn, false_fn)
|
|
||||||
val = v.read_value()
|
|
||||||
val = c.mark_as_return(val)
|
|
||||||
self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0)
|
|
||||||
self.assertAllEqual(val.eval(feed_dict={p: True}), 5.0)
|
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testCondOneBranchUpdateBefore(self):
|
|
||||||
with context.graph_mode(), self.cached_session():
|
|
||||||
v = resource_variable_ops.ResourceVariable(1.0)
|
|
||||||
self.evaluate(variables.global_variables_initializer())
|
|
||||||
p = array_ops.placeholder(dtype=dtypes.bool)
|
|
||||||
with acd.AutomaticControlDependencies() as c:
|
|
||||||
v.assign(v * 2)
|
|
||||||
|
|
||||||
def true_fn():
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
def false_fn():
|
|
||||||
v.assign(v + 4)
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
control_flow_ops.cond(p, true_fn, false_fn)
|
|
||||||
val = v.read_value()
|
|
||||||
val = c.mark_as_return(val)
|
|
||||||
self.assertAllEqual(val.eval(feed_dict={p: False}), 6.0)
|
|
||||||
self.assertAllEqual(val.eval(feed_dict={p: True}), 12.0)
|
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testCondOneBranchUpdateAfter(self):
|
|
||||||
with context.graph_mode(), self.cached_session():
|
|
||||||
v = resource_variable_ops.ResourceVariable(1.0)
|
|
||||||
self.evaluate(variables.global_variables_initializer())
|
|
||||||
p = array_ops.placeholder(dtype=dtypes.bool)
|
|
||||||
with acd.AutomaticControlDependencies() as c:
|
|
||||||
|
|
||||||
def true_fn():
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
def false_fn():
|
|
||||||
v.assign(v + 4)
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
control_flow_ops.cond(p, true_fn, false_fn)
|
|
||||||
v.assign(v * 2)
|
|
||||||
val = v.read_value()
|
|
||||||
val = c.mark_as_return(val)
|
|
||||||
self.assertAllEqual(val.eval(feed_dict={p: False}), 10.0)
|
|
||||||
self.assertAllEqual(val.eval(feed_dict={p: True}), 20.0)
|
|
||||||
|
|
||||||
def testDefunWhileLoopWithCapturedLoopVars(self):
|
def testDefunWhileLoopWithCapturedLoopVars(self):
|
||||||
n = 3
|
n = 3
|
||||||
x = constant_op.constant(list(range(n)))
|
x = constant_op.constant(list(range(n)))
|
||||||
|
Loading…
Reference in New Issue
Block a user