Support ordering of certain collectives.
PiperOrigin-RevId: 298719882 Change-Id: I53b83f2b26f431ef7d1258280d51c0604b19b50f
This commit is contained in:
parent
4614140908
commit
0d145969f6
tensorflow/python
@ -129,6 +129,32 @@ class ResourceType(enum.Enum):
|
||||
READ_WRITE = "read-write"
|
||||
|
||||
|
||||
def collective_manager_ids_from_op(op):
  """Returns the CollectiveManager IDs attached to `op`, if any.

  CollectiveManager adds collective and no_op operations tagged with an ID,
  unique to the manager object. This function extracts the ID(s) recorded on
  the op:

  * a "CollectiveReduce" op carries one ID in its "_collective_manager_id"
    attr, which is returned wrapped in a single-element list;
  * a "StatefulPartitionedCall" op carries its IDs in the
    `utils.COLLECTIVE_MANAGER_IDS` attr, whose value is returned as is.

  Args:
    op: `Operation` to get the collective manager IDs from.

  Returns:
    List of CollectiveManager IDs used by the op. The list is empty (never
    `None`) when the op is of any other type or does not have the attr, so the
    result can always be iterated over.
  """
  if op.type == "CollectiveReduce":
    attr_name, is_list_attr = "_collective_manager_id", False
  elif op.type == "StatefulPartitionedCall":
    attr_name, is_list_attr = utils.COLLECTIVE_MANAGER_IDS, True
  else:
    # Only the two op types above are inspected for manager IDs.
    return []

  try:
    value = op.get_attr(attr_name)
  except ValueError:
    # `get_attr` signals a missing attr with ValueError: the op was not tagged.
    return []
  return value if is_list_attr else [value]
|
||||
|
||||
|
||||
class AutomaticControlDependencies(object):
|
||||
"""Context manager to automatically add control dependencies.
|
||||
|
||||
@ -241,6 +267,7 @@ class AutomaticControlDependencies(object):
|
||||
merge_for_resource: map from resource tensor to merge which must follow
|
||||
all usages of it.
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
inp = switch_op.inputs[0]
|
||||
input_id = ops.tensor_id(inp)
|
||||
if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
|
||||
@ -250,24 +277,25 @@ class AutomaticControlDependencies(object):
|
||||
output_id = ops.tensor_id(output)
|
||||
if output_id in merge_for_resource:
|
||||
return
|
||||
new_merge = control_flow_ops.merge(switch_op.outputs,
|
||||
name="artificial_merge")
|
||||
new_merge[0].op._control_flow_context = ( # pylint: disable=protected-access
|
||||
switch_op._control_flow_context.outer_context) # pylint: disable=protected-access
|
||||
new_merge = control_flow_ops.merge(
|
||||
switch_op.outputs, name="artificial_merge")
|
||||
new_merge[0].op._control_flow_context = (
|
||||
switch_op._control_flow_context.outer_context)
|
||||
# Ensures the merge always runs
|
||||
ops_which_must_run.add(new_merge[0].op)
|
||||
if input_id in last_write_to_resource:
|
||||
# Ensures the switch executes after the previous op using the resource.
|
||||
switch_op._add_control_input(last_write_to_resource[input_id]) # pylint: disable=protected-access
|
||||
switch_op._add_control_input(last_write_to_resource[input_id])
|
||||
# Ensure the next op outside the cond happens after the merge.
|
||||
last_write_to_resource[input_id] = new_merge[0].op
|
||||
if input_id in merge_for_resource:
|
||||
merge_for_resource[input_id]._add_control_input(new_merge[0].op) # pylint: disable=protected-access
|
||||
merge_for_resource[input_id]._add_control_input(new_merge[0].op)
|
||||
for o in switch_op.outputs:
|
||||
# Ensures the merge will execute after all ops inside the cond
|
||||
merge_for_resource[ops.tensor_id(o)] = new_merge[0].op
|
||||
|
||||
def __exit__(self, unused_type, unused_value, unused_traceback):
|
||||
# pylint: disable=protected-access
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
|
||||
@ -275,19 +303,24 @@ class AutomaticControlDependencies(object):
|
||||
raise RuntimeError(
|
||||
"Graph changed while trying to add control dependencies.")
|
||||
|
||||
# pylint: disable=protected-access
|
||||
if hasattr(self._graph, "outer_graph"):
|
||||
outer_val = self._graph.outer_graph._add_control_dependencies
|
||||
self._graph._add_control_dependencies = outer_val
|
||||
else:
|
||||
self._graph._add_control_dependencies = False
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# map from resource tensor to the last op which wrote to it
|
||||
last_write_to_resource = {}
|
||||
# map from resource tensor to the list of reads from it since the last
|
||||
# write or since the beginning of the function.
|
||||
reads_since_last_write_to_resource = collections.defaultdict(list)
|
||||
# CollectiveManager manager_ids within a particular function call should not
|
||||
# be needed outside of that function call. So we keep them separate (though
|
||||
# the general idea of the maps is the same, in the future, we'll need to
|
||||
# correctly thread the control output outside).
|
||||
# Map from collective manager scope to the last op which used it
|
||||
collective_manager_scopes_opened = {}
|
||||
collective_manager_scopes_used = {}
|
||||
# set of conditional and loop exits
|
||||
ops_which_must_run = set()
|
||||
# merge which must depend on ops which use this resource
|
||||
@ -334,13 +367,20 @@ class AutomaticControlDependencies(object):
|
||||
# TODO(srbs): Do not add functional ops to `ops_which_must_run` if
|
||||
# they only have variable reads and are otherwise stateless.
|
||||
ops_which_must_run.add(op)
|
||||
# Make a note of all opened manager_ids.
|
||||
if op.type == "NoOp":
|
||||
try:
|
||||
collective_manager_scopes_opened[op.get_attr(
|
||||
"_collective_manager_id")] = op
|
||||
except ValueError:
|
||||
pass
|
||||
# Ignore switches (they're handled separately)
|
||||
if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
|
||||
continue
|
||||
# Make merges trigger all other computation which must run
|
||||
if op.type == "Merge":
|
||||
for o in ops_which_must_run:
|
||||
op._add_control_input(o) # pylint: disable=protected-access
|
||||
op._add_control_input(o)
|
||||
for inp in o.inputs:
|
||||
input_id = ops.tensor_id(inp)
|
||||
if input_id in last_write_to_resource:
|
||||
@ -369,12 +409,12 @@ class AutomaticControlDependencies(object):
|
||||
# Ensure uses of resources are serialized
|
||||
if input_id in last_write_to_resource:
|
||||
if is_building_function or (
|
||||
last_write_to_resource[input_id]._control_flow_context # pylint: disable=protected-access
|
||||
is op._control_flow_context): # pylint: disable=protected-access
|
||||
last_write_to_resource[input_id]._control_flow_context
|
||||
is op._control_flow_context):
|
||||
control_inputs.add(last_write_to_resource[input_id])
|
||||
# Ensure merges happen after the closing of a cond block
|
||||
if input_id in merge_for_resource:
|
||||
merge_for_resource[input_id]._add_control_input(op) # pylint: disable=protected-access
|
||||
merge_for_resource[input_id]._add_control_input(op)
|
||||
if is_read:
|
||||
reads_since_last_write_to_resource[input_id].append(op)
|
||||
else:
|
||||
@ -383,25 +423,48 @@ class AutomaticControlDependencies(object):
|
||||
last_write_to_resource[input_id] = op
|
||||
|
||||
if (op_is_stateful(op) and not resource_inputs
|
||||
and op._control_flow_context is None): # pylint: disable=protected-access
|
||||
and op._control_flow_context is None):
|
||||
if None in last_write_to_resource:
|
||||
op._add_control_input(last_write_to_resource[None]) # pylint: disable=protected-access
|
||||
op._add_control_input(last_write_to_resource[None])
|
||||
last_write_to_resource[None] = op
|
||||
control_inputs = [
|
||||
c for c in control_inputs if is_building_function or
|
||||
(c._control_flow_context is op._control_flow_context)] # pylint: disable=protected-access
|
||||
op._add_control_inputs(control_inputs) # pylint: disable=protected-access
|
||||
|
||||
# Ensure ordering of collective ops
|
||||
manager_ids = collective_manager_ids_from_op(op)
|
||||
for manager_id in manager_ids:
|
||||
if manager_id in collective_manager_scopes_opened:
|
||||
# Chain this function call if the scope was opened.
|
||||
op._add_control_input(collective_manager_scopes_opened[manager_id])
|
||||
collective_manager_scopes_opened[manager_id] = op
|
||||
else:
|
||||
# If this op is in a scope not created here, create a chain starting
|
||||
# at this op.
|
||||
if manager_id in collective_manager_scopes_used:
|
||||
op._add_control_input(collective_manager_scopes_used[manager_id])
|
||||
collective_manager_scopes_used[manager_id] = op
|
||||
|
||||
if control_inputs and not is_building_function:
|
||||
control_inputs = [
|
||||
c for c in control_inputs
|
||||
if c._control_flow_context is op._control_flow_context
|
||||
]
|
||||
|
||||
op._add_control_inputs(control_inputs)
|
||||
|
||||
# Ensure all ops which must run do run
|
||||
self.ops_which_must_run.update(ops_which_must_run)
|
||||
for r in nest.flatten(list(self._returned_tensors), expand_composites=True):
|
||||
if self.ops_which_must_run:
|
||||
r.op._add_control_inputs( # pylint: disable=protected-access
|
||||
[
|
||||
o for o in self.ops_which_must_run
|
||||
if r.graph.building_function or
|
||||
(o._control_flow_context is r.op._control_flow_context) # pylint: disable=protected-access
|
||||
])
|
||||
updated_ops_which_must_run = []
|
||||
if r.graph.building_function:
|
||||
updated_ops_which_must_run = self.ops_which_must_run
|
||||
else:
|
||||
updated_ops_which_must_run = [
|
||||
o for o in self.ops_which_must_run
|
||||
if o._control_flow_context is r.op._control_flow_context
|
||||
]
|
||||
r.op._add_control_inputs(updated_ops_which_must_run)
|
||||
|
||||
self.collective_manager_ids_used = collective_manager_scopes_used
|
||||
|
||||
|
||||
_acd_resource_resolvers_registry = registry.Registry("acd_resource_resolvers")
|
||||
|
@ -25,6 +25,9 @@ READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs"
|
||||
RESOURCE_READ_OPS = set()
|
||||
|
||||
|
||||
COLLECTIVE_MANAGER_IDS = "_collective_manager_ids"
|
||||
|
||||
|
||||
def register_read_only_resource_op(op_type):
  """Marks `op_type` as an op that only reads the resource it touches.

  The op type is recorded in the module-level `RESOURCE_READ_OPS` set.

  Args:
    op_type: the op type to record as not updating its touched resource.
  """
  # Single-element update; registering the same op type again is a no-op.
  RESOURCE_READ_OPS.update((op_type,))
|
||||
|
@ -1023,6 +1023,8 @@ def func_graph_from_py_func(name,
|
||||
|
||||
if add_control_dependencies:
|
||||
func_graph.control_outputs.extend(deps_control_manager.ops_which_must_run)
|
||||
func_graph.collective_manager_ids_used = (
|
||||
deps_control_manager.collective_manager_ids_used)
|
||||
|
||||
return func_graph
|
||||
|
||||
|
@ -1181,6 +1181,9 @@ def partitioned_call(args,
|
||||
outputs = op.outputs
|
||||
if hasattr(f, "graph"):
|
||||
_set_read_only_resource_inputs_attr(op, f.graph)
|
||||
if hasattr(f.graph, "collective_manager_ids_used"):
|
||||
ops.set_int_list_attr(
|
||||
op, acd.COLLECTIVE_MANAGER_IDS, f.graph.collective_manager_ids_used)
|
||||
return outputs if outputs else op
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user