From 0d145969f60ec7a9e61ef0ce33f20304e99ff2c0 Mon Sep 17 00:00:00 2001 From: Akshay Modi <nareshmodi@google.com> Date: Tue, 3 Mar 2020 16:19:30 -0800 Subject: [PATCH] Support ordering of certain collectives. PiperOrigin-RevId: 298719882 Change-Id: I53b83f2b26f431ef7d1258280d51c0604b19b50f --- .../python/framework/auto_control_deps.py | 111 ++++++++++++++---- .../framework/auto_control_deps_utils.py | 3 + tensorflow/python/framework/func_graph.py | 2 + tensorflow/python/ops/functional_ops.py | 3 + 4 files changed, 95 insertions(+), 24 deletions(-) diff --git a/tensorflow/python/framework/auto_control_deps.py b/tensorflow/python/framework/auto_control_deps.py index d4ef1a1de19..c674717482d 100644 --- a/tensorflow/python/framework/auto_control_deps.py +++ b/tensorflow/python/framework/auto_control_deps.py @@ -129,6 +129,32 @@ class ResourceType(enum.Enum): READ_WRITE = "read-write" +def collective_manager_ids_from_op(op): + """Returns CollectiveManager ID from the op if one exists, else None. + + CollectiveManager adds collective and no_op operations tagged with an ID, + unique to the manager object. This function extracts that ID, or None, if the + node was not generated by a CollectiveManager. + + Args: + op: `Operation` to get the collective manager ID from. + + Returns: + List of CollectiveManager IDs used by the op. + """ + if op.type == "CollectiveReduce": + try: + return [op.get_attr("_collective_manager_id")] + except ValueError: + pass + elif op.type == "StatefulPartitionedCall": + try: + return op.get_attr(utils.COLLECTIVE_MANAGER_IDS) + except ValueError: + pass + return [] + + class AutomaticControlDependencies(object): """Context manager to automatically add control dependencies. @@ -241,6 +267,7 @@ class AutomaticControlDependencies(object): merge_for_resource: map from resource tensor to merge which must follow all usages of it. """ + # pylint: disable=protected-access inp = switch_op.inputs[0] input_id = ops.tensor_id(inp) if inp.dtype == dtypes_module.resource and inp.op.type == "Switch": @@ -250,24 +277,25 @@ class AutomaticControlDependencies(object): output_id = ops.tensor_id(output) if output_id in merge_for_resource: return - new_merge = control_flow_ops.merge(switch_op.outputs, - name="artificial_merge") - new_merge[0].op._control_flow_context = ( # pylint: disable=protected-access - switch_op._control_flow_context.outer_context) # pylint: disable=protected-access + new_merge = control_flow_ops.merge( + switch_op.outputs, name="artificial_merge") + new_merge[0].op._control_flow_context = ( + switch_op._control_flow_context.outer_context) # Ensures the merge always runs ops_which_must_run.add(new_merge[0].op) if input_id in last_write_to_resource: # Ensures the switch executes after the previous op using the resource. - switch_op._add_control_input(last_write_to_resource[input_id]) # pylint: disable=protected-access + switch_op._add_control_input(last_write_to_resource[input_id]) # Ensure the next op outside the cond happens after the merge. last_write_to_resource[input_id] = new_merge[0].op if input_id in merge_for_resource: - merge_for_resource[input_id]._add_control_input(new_merge[0].op) # pylint: disable=protected-access + merge_for_resource[input_id]._add_control_input(new_merge[0].op) for o in switch_op.outputs: # Ensures the merge will execute after all ops inside the cond merge_for_resource[ops.tensor_id(o)] = new_merge[0].op def __exit__(self, unused_type, unused_value, unused_traceback): + # pylint: disable=protected-access if context.executing_eagerly(): return @@ -275,19 +303,24 @@ class AutomaticControlDependencies(object): raise RuntimeError( "Graph changed while trying to add control dependencies.") - # pylint: disable=protected-access if hasattr(self._graph, "outer_graph"): outer_val = self._graph.outer_graph._add_control_dependencies self._graph._add_control_dependencies = outer_val else: self._graph._add_control_dependencies = False - # pylint: enable=protected-access # map from resource tensor to the last op which wrote to it last_write_to_resource = {} # map from resource tensor to the list of reads from it since the last # write or since the beginning of the function. reads_since_last_write_to_resource = collections.defaultdict(list) + # CollectiveManager manager_ids within a particular function call should not + # be needed outside of that function call. So we keep them separate (though + # the general idea of the maps is the same, in the future, we'll need to + # correctly thread the control output outside). + # Map from collective manager scope to the last op which used it + collective_manager_scopes_opened = {} + collective_manager_scopes_used = {} # set of conditional and loop exits ops_which_must_run = set() # merge which must depend on ops which use this resource @@ -334,13 +367,20 @@ class AutomaticControlDependencies(object): # TODO(srbs): Do not add functional ops to `ops_which_must_run` if # they only have variable reads and are otherwise stateless. ops_which_must_run.add(op) + # Make a note of all opened manager_ids. + if op.type == "NoOp": + try: + collective_manager_scopes_opened[op.get_attr( + "_collective_manager_id")] = op + except ValueError: + pass # Ignore switches (they're handled separately) if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource: continue # Make merges trigger all other computation which must run if op.type == "Merge": for o in ops_which_must_run: - op._add_control_input(o) # pylint: disable=protected-access + op._add_control_input(o) for inp in o.inputs: input_id = ops.tensor_id(inp) if input_id in last_write_to_resource: @@ -369,12 +409,12 @@ class AutomaticControlDependencies(object): # Ensure uses of resources are serialized if input_id in last_write_to_resource: if is_building_function or ( - last_write_to_resource[input_id]._control_flow_context # pylint: disable=protected-access - is op._control_flow_context): # pylint: disable=protected-access + last_write_to_resource[input_id]._control_flow_context + is op._control_flow_context): control_inputs.add(last_write_to_resource[input_id]) # Ensure merges happen after the closing of a cond block if input_id in merge_for_resource: - merge_for_resource[input_id]._add_control_input(op) # pylint: disable=protected-access + merge_for_resource[input_id]._add_control_input(op) if is_read: reads_since_last_write_to_resource[input_id].append(op) else: @@ -383,25 +423,48 @@ class AutomaticControlDependencies(object): last_write_to_resource[input_id] = op if (op_is_stateful(op) and not resource_inputs - and op._control_flow_context is None): # pylint: disable=protected-access + and op._control_flow_context is None): if None in last_write_to_resource: - op._add_control_input(last_write_to_resource[None]) # pylint: disable=protected-access + op._add_control_input(last_write_to_resource[None]) last_write_to_resource[None] = op - control_inputs = [ - c for c in control_inputs if is_building_function or - (c._control_flow_context is op._control_flow_context)] # pylint: disable=protected-access - op._add_control_inputs(control_inputs) # pylint: disable=protected-access + + # Ensure ordering of collective ops + manager_ids = collective_manager_ids_from_op(op) + for manager_id in manager_ids: + if manager_id in collective_manager_scopes_opened: + # Chain this function call if the scope was opened. + op._add_control_input(collective_manager_scopes_opened[manager_id]) + collective_manager_scopes_opened[manager_id] = op + else: + # If this op is in a scope not created here, create a chain starting + # at this op. + if manager_id in collective_manager_scopes_used: + op._add_control_input(collective_manager_scopes_used[manager_id]) + collective_manager_scopes_used[manager_id] = op + + if control_inputs and not is_building_function: + control_inputs = [ + c for c in control_inputs + if c._control_flow_context is op._control_flow_context + ] + + op._add_control_inputs(control_inputs) # Ensure all ops which must run do run self.ops_which_must_run.update(ops_which_must_run) for r in nest.flatten(list(self._returned_tensors), expand_composites=True): if self.ops_which_must_run: - r.op._add_control_inputs( # pylint: disable=protected-access - [ - o for o in self.ops_which_must_run - if r.graph.building_function or - (o._control_flow_context is r.op._control_flow_context) # pylint: disable=protected-access - ]) + updated_ops_which_must_run = [] + if r.graph.building_function: + updated_ops_which_must_run = self.ops_which_must_run + else: + updated_ops_which_must_run = [ + o for o in self.ops_which_must_run + if o._control_flow_context is r.op._control_flow_context + ] + r.op._add_control_inputs(updated_ops_which_must_run) + + self.collective_manager_ids_used = collective_manager_scopes_used _acd_resource_resolvers_registry = registry.Registry("acd_resource_resolvers") diff --git a/tensorflow/python/framework/auto_control_deps_utils.py b/tensorflow/python/framework/auto_control_deps_utils.py index f1b23556c7a..63ca73bb034 100644 --- a/tensorflow/python/framework/auto_control_deps_utils.py +++ b/tensorflow/python/framework/auto_control_deps_utils.py @@ -25,6 +25,9 @@ READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs" RESOURCE_READ_OPS = set() +COLLECTIVE_MANAGER_IDS = "_collective_manager_ids" + + def register_read_only_resource_op(op_type): """Declares that `op_type` does not update its touched resource.""" RESOURCE_READ_OPS.add(op_type) diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py index d686df562a6..d702771cef3 100644 --- a/tensorflow/python/framework/func_graph.py +++ b/tensorflow/python/framework/func_graph.py @@ -1023,6 +1023,8 @@ def func_graph_from_py_func(name, if add_control_dependencies: func_graph.control_outputs.extend(deps_control_manager.ops_which_must_run) + func_graph.collective_manager_ids_used = ( + deps_control_manager.collective_manager_ids_used) return func_graph diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py index a90f223ac92..63c653b5df1 100644 --- a/tensorflow/python/ops/functional_ops.py +++ b/tensorflow/python/ops/functional_ops.py @@ -1181,6 +1181,9 @@ def partitioned_call(args, outputs = op.outputs if hasattr(f, "graph"): _set_read_only_resource_inputs_attr(op, f.graph) + if hasattr(f.graph, "collective_manager_ids_used"): + ops.set_int_list_attr( + op, acd.COLLECTIVE_MANAGER_IDS, f.graph.collective_manager_ids_used) return outputs if outputs else op