Support ordering of certain collectives.

PiperOrigin-RevId: 298719882
Change-Id: I53b83f2b26f431ef7d1258280d51c0604b19b50f
This commit is contained in:
Akshay Modi 2020-03-03 16:19:30 -08:00 committed by TensorFlower Gardener
parent 4614140908
commit 0d145969f6
4 changed files with 95 additions and 24 deletions

View File

@@ -129,6 +129,32 @@ class ResourceType(enum.Enum):
READ_WRITE = "read-write"
def collective_manager_ids_from_op(op):
  """Returns a list of CollectiveManager IDs used by the op.

  CollectiveManager adds collective and no_op operations tagged with an ID,
  unique to the manager object. This function extracts that ID, returning an
  empty list if the node was not generated by a CollectiveManager.

  Args:
    op: `Operation` to get the collective manager ID from.

  Returns:
    List of CollectiveManager IDs used by the op; empty if the op was not
    generated by a CollectiveManager.
  """
  if op.type == "CollectiveReduce":
    try:
      # Ops emitted directly by a CollectiveManager carry a single ID attr.
      return [op.get_attr("_collective_manager_id")]
    except ValueError:
      # Attr absent: the op was not created by a CollectiveManager.
      pass
  elif op.type == "StatefulPartitionedCall":
    try:
      # Function calls aggregate the IDs of all collectives in their graph.
      return op.get_attr(utils.COLLECTIVE_MANAGER_IDS)
    except ValueError:
      pass
  return []
class AutomaticControlDependencies(object):
"""Context manager to automatically add control dependencies.
@@ -241,6 +267,7 @@ class AutomaticControlDependencies(object):
merge_for_resource: map from resource tensor to merge which must follow
all usages of it.
"""
# pylint: disable=protected-access
inp = switch_op.inputs[0]
input_id = ops.tensor_id(inp)
if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
@@ -250,24 +277,25 @@ class AutomaticControlDependencies(object):
output_id = ops.tensor_id(output)
if output_id in merge_for_resource:
return
new_merge = control_flow_ops.merge(switch_op.outputs,
name="artificial_merge")
new_merge[0].op._control_flow_context = ( # pylint: disable=protected-access
switch_op._control_flow_context.outer_context) # pylint: disable=protected-access
new_merge = control_flow_ops.merge(
switch_op.outputs, name="artificial_merge")
new_merge[0].op._control_flow_context = (
switch_op._control_flow_context.outer_context)
# Ensures the merge always runs
ops_which_must_run.add(new_merge[0].op)
if input_id in last_write_to_resource:
# Ensures the switch executes after the previous op using the resource.
switch_op._add_control_input(last_write_to_resource[input_id]) # pylint: disable=protected-access
switch_op._add_control_input(last_write_to_resource[input_id])
# Ensure the next op outside the cond happens after the merge.
last_write_to_resource[input_id] = new_merge[0].op
if input_id in merge_for_resource:
merge_for_resource[input_id]._add_control_input(new_merge[0].op) # pylint: disable=protected-access
merge_for_resource[input_id]._add_control_input(new_merge[0].op)
for o in switch_op.outputs:
# Ensures the merge will execute after all ops inside the cond
merge_for_resource[ops.tensor_id(o)] = new_merge[0].op
def __exit__(self, unused_type, unused_value, unused_traceback):
# pylint: disable=protected-access
if context.executing_eagerly():
return
@@ -275,19 +303,24 @@ class AutomaticControlDependencies(object):
raise RuntimeError(
"Graph changed while trying to add control dependencies.")
# pylint: disable=protected-access
if hasattr(self._graph, "outer_graph"):
outer_val = self._graph.outer_graph._add_control_dependencies
self._graph._add_control_dependencies = outer_val
else:
self._graph._add_control_dependencies = False
# pylint: enable=protected-access
# map from resource tensor to the last op which wrote to it
last_write_to_resource = {}
# map from resource tensor to the list of reads from it since the last
# write or since the beginning of the function.
reads_since_last_write_to_resource = collections.defaultdict(list)
# CollectiveManager manager_ids within a particular function call should not
# be needed outside of that function call. So we keep them separate (though
# the general idea of the maps is the same, in the future, we'll need to
# correctly thread the control output outside).
# Map from collective manager scope to the last op which used it
collective_manager_scopes_opened = {}
collective_manager_scopes_used = {}
# set of conditional and loop exits
ops_which_must_run = set()
# merge which must depend on ops which use this resource
@@ -334,13 +367,20 @@ class AutomaticControlDependencies(object):
# TODO(srbs): Do not add functional ops to `ops_which_must_run` if
# they only have variable reads and are otherwise stateless.
ops_which_must_run.add(op)
# Make a note of all opened manager_ids.
if op.type == "NoOp":
try:
collective_manager_scopes_opened[op.get_attr(
"_collective_manager_id")] = op
except ValueError:
pass
# Ignore switches (they're handled separately)
if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
continue
# Make merges trigger all other computation which must run
if op.type == "Merge":
for o in ops_which_must_run:
op._add_control_input(o) # pylint: disable=protected-access
op._add_control_input(o)
for inp in o.inputs:
input_id = ops.tensor_id(inp)
if input_id in last_write_to_resource:
@@ -369,12 +409,12 @@ class AutomaticControlDependencies(object):
# Ensure uses of resources are serialized
if input_id in last_write_to_resource:
if is_building_function or (
last_write_to_resource[input_id]._control_flow_context # pylint: disable=protected-access
is op._control_flow_context): # pylint: disable=protected-access
last_write_to_resource[input_id]._control_flow_context
is op._control_flow_context):
control_inputs.add(last_write_to_resource[input_id])
# Ensure merges happen after the closing of a cond block
if input_id in merge_for_resource:
merge_for_resource[input_id]._add_control_input(op) # pylint: disable=protected-access
merge_for_resource[input_id]._add_control_input(op)
if is_read:
reads_since_last_write_to_resource[input_id].append(op)
else:
@@ -383,25 +423,48 @@ class AutomaticControlDependencies(object):
last_write_to_resource[input_id] = op
if (op_is_stateful(op) and not resource_inputs
and op._control_flow_context is None): # pylint: disable=protected-access
and op._control_flow_context is None):
if None in last_write_to_resource:
op._add_control_input(last_write_to_resource[None]) # pylint: disable=protected-access
op._add_control_input(last_write_to_resource[None])
last_write_to_resource[None] = op
control_inputs = [
c for c in control_inputs if is_building_function or
(c._control_flow_context is op._control_flow_context)] # pylint: disable=protected-access
op._add_control_inputs(control_inputs) # pylint: disable=protected-access
# Ensure ordering of collective ops
manager_ids = collective_manager_ids_from_op(op)
for manager_id in manager_ids:
if manager_id in collective_manager_scopes_opened:
# Chain this function call if the scope was opened.
op._add_control_input(collective_manager_scopes_opened[manager_id])
collective_manager_scopes_opened[manager_id] = op
else:
# If this op is in a scope not created here, create a chain starting
# at this op.
if manager_id in collective_manager_scopes_used:
op._add_control_input(collective_manager_scopes_used[manager_id])
collective_manager_scopes_used[manager_id] = op
if control_inputs and not is_building_function:
control_inputs = [
c for c in control_inputs
if c._control_flow_context is op._control_flow_context
]
op._add_control_inputs(control_inputs)
# Ensure all ops which must run do run
self.ops_which_must_run.update(ops_which_must_run)
for r in nest.flatten(list(self._returned_tensors), expand_composites=True):
if self.ops_which_must_run:
r.op._add_control_inputs( # pylint: disable=protected-access
[
o for o in self.ops_which_must_run
if r.graph.building_function or
(o._control_flow_context is r.op._control_flow_context) # pylint: disable=protected-access
])
updated_ops_which_must_run = []
if r.graph.building_function:
updated_ops_which_must_run = self.ops_which_must_run
else:
updated_ops_which_must_run = [
o for o in self.ops_which_must_run
if o._control_flow_context is r.op._control_flow_context
]
r.op._add_control_inputs(updated_ops_which_must_run)
self.collective_manager_ids_used = collective_manager_scopes_used
_acd_resource_resolvers_registry = registry.Registry("acd_resource_resolvers")

View File

@@ -25,6 +25,9 @@ READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs"
# Op types registered as read-only with respect to their resource inputs.
RESOURCE_READ_OPS = set()

# Attr name under which function-call ops record collective manager IDs.
COLLECTIVE_MANAGER_IDS = "_collective_manager_ids"


def register_read_only_resource_op(op_type):
  """Records that ops of `op_type` only read, never write, their resources."""
  RESOURCE_READ_OPS.add(op_type)

View File

@@ -1023,6 +1023,8 @@ def func_graph_from_py_func(name,
if add_control_dependencies:
func_graph.control_outputs.extend(deps_control_manager.ops_which_must_run)
func_graph.collective_manager_ids_used = (
deps_control_manager.collective_manager_ids_used)
return func_graph

View File

@@ -1181,6 +1181,9 @@ def partitioned_call(args,
outputs = op.outputs
if hasattr(f, "graph"):
_set_read_only_resource_inputs_attr(op, f.graph)
if hasattr(f.graph, "collective_manager_ids_used"):
ops.set_int_list_attr(
op, acd.COLLECTIVE_MANAGER_IDS, f.graph.collective_manager_ids_used)
return outputs if outputs else op