From 0d145969f60ec7a9e61ef0ce33f20304e99ff2c0 Mon Sep 17 00:00:00 2001
From: Akshay Modi <nareshmodi@google.com>
Date: Tue, 3 Mar 2020 16:19:30 -0800
Subject: [PATCH] Support ordering of certain collectives.

PiperOrigin-RevId: 298719882
Change-Id: I53b83f2b26f431ef7d1258280d51c0604b19b50f
---
 .../python/framework/auto_control_deps.py     | 111 ++++++++++++++----
 .../framework/auto_control_deps_utils.py      |   3 +
 tensorflow/python/framework/func_graph.py     |   2 +
 tensorflow/python/ops/functional_ops.py       |   3 +
 4 files changed, 95 insertions(+), 24 deletions(-)

diff --git a/tensorflow/python/framework/auto_control_deps.py b/tensorflow/python/framework/auto_control_deps.py
index d4ef1a1de19..c674717482d 100644
--- a/tensorflow/python/framework/auto_control_deps.py
+++ b/tensorflow/python/framework/auto_control_deps.py
@@ -129,6 +129,32 @@ class ResourceType(enum.Enum):
   READ_WRITE = "read-write"
 
 
+def collective_manager_ids_from_op(op):
+  """Returns the list of CollectiveManager IDs used by the op, if any.
+
+  CollectiveManager adds collective and no_op operations tagged with an ID,
+  unique to the manager object. This function extracts those IDs, returning an
+  empty list if the node was not generated by a CollectiveManager.
+
+  Args:
+    op: `Operation` to get the collective manager ID from.
+
+  Returns:
+    List of CollectiveManager IDs used by the op.
+  """
+  if op.type == "CollectiveReduce":
+    try:
+      return [op.get_attr("_collective_manager_id")]
+    except ValueError:
+      pass
+  elif op.type == "StatefulPartitionedCall":
+    try:
+      return op.get_attr(utils.COLLECTIVE_MANAGER_IDS)
+    except ValueError:
+      pass
+  return []
+
+
 class AutomaticControlDependencies(object):
   """Context manager to automatically add control dependencies.
 
@@ -241,6 +267,7 @@ class AutomaticControlDependencies(object):
       merge_for_resource: map from resource tensor to merge which must follow
         all usages of it.
     """
+    # pylint: disable=protected-access
     inp = switch_op.inputs[0]
     input_id = ops.tensor_id(inp)
     if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
@@ -250,24 +277,25 @@ class AutomaticControlDependencies(object):
     output_id = ops.tensor_id(output)
     if output_id in merge_for_resource:
       return
-    new_merge = control_flow_ops.merge(switch_op.outputs,
-                                       name="artificial_merge")
-    new_merge[0].op._control_flow_context = (  # pylint: disable=protected-access
-        switch_op._control_flow_context.outer_context)  # pylint: disable=protected-access
+    new_merge = control_flow_ops.merge(
+        switch_op.outputs, name="artificial_merge")
+    new_merge[0].op._control_flow_context = (
+        switch_op._control_flow_context.outer_context)
     # Ensures the merge always runs
     ops_which_must_run.add(new_merge[0].op)
     if input_id in last_write_to_resource:
       # Ensures the switch executes after the previous op using the resource.
-      switch_op._add_control_input(last_write_to_resource[input_id])  # pylint: disable=protected-access
+      switch_op._add_control_input(last_write_to_resource[input_id])
     # Ensure the next op outside the cond happens after the merge.
     last_write_to_resource[input_id] = new_merge[0].op
     if input_id in merge_for_resource:
-      merge_for_resource[input_id]._add_control_input(new_merge[0].op)  # pylint: disable=protected-access
+      merge_for_resource[input_id]._add_control_input(new_merge[0].op)
     for o in switch_op.outputs:
       # Ensures the merge will execute after all ops inside the cond
       merge_for_resource[ops.tensor_id(o)] = new_merge[0].op
 
   def __exit__(self, unused_type, unused_value, unused_traceback):
+    # pylint: disable=protected-access
     if context.executing_eagerly():
       return
 
@@ -275,19 +303,24 @@ class AutomaticControlDependencies(object):
       raise RuntimeError(
           "Graph changed while trying to add control dependencies.")
 
-    # pylint: disable=protected-access
     if hasattr(self._graph, "outer_graph"):
       outer_val = self._graph.outer_graph._add_control_dependencies
       self._graph._add_control_dependencies = outer_val
     else:
       self._graph._add_control_dependencies = False
-    # pylint: enable=protected-access
 
     # map from resource tensor to the last op which wrote to it
     last_write_to_resource = {}
     # map from resource tensor to the list of reads from it since the last
     # write or since the beginning of the function.
     reads_since_last_write_to_resource = collections.defaultdict(list)
+    # CollectiveManager manager_ids within a particular function call should not
+    # be needed outside of that function call, so we track them in separate maps
+    # (their structure mirrors the resource maps above; in the future we will
+    # need to correctly thread the control output outside the function).
+    # Map from collective manager scope to the last op which used it
+    collective_manager_scopes_opened = {}
+    collective_manager_scopes_used = {}
     # set of conditional and loop exits
     ops_which_must_run = set()
     # merge which must depend on ops which use this resource
@@ -334,13 +367,20 @@ class AutomaticControlDependencies(object):
         # TODO(srbs): Do not add functional ops to `ops_which_must_run` if
         # they only have variable reads and are otherwise stateless.
         ops_which_must_run.add(op)
+      # Make a note of all opened manager_ids.
+      if op.type == "NoOp":
+        try:
+          collective_manager_scopes_opened[op.get_attr(
+              "_collective_manager_id")] = op
+        except ValueError:
+          pass
       # Ignore switches (they're handled separately)
       if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
         continue
       # Make merges trigger all other computation which must run
       if op.type == "Merge":
         for o in ops_which_must_run:
-          op._add_control_input(o)  # pylint: disable=protected-access
+          op._add_control_input(o)
           for inp in o.inputs:
             input_id = ops.tensor_id(inp)
             if input_id in last_write_to_resource:
@@ -369,12 +409,12 @@ class AutomaticControlDependencies(object):
         # Ensure uses of resources are serialized
         if input_id in last_write_to_resource:
           if is_building_function or (
-              last_write_to_resource[input_id]._control_flow_context  # pylint: disable=protected-access
-              is op._control_flow_context):  # pylint: disable=protected-access
+              last_write_to_resource[input_id]._control_flow_context
+              is op._control_flow_context):
             control_inputs.add(last_write_to_resource[input_id])
         # Ensure merges happen after the closing of a cond block
         if input_id in merge_for_resource:
-          merge_for_resource[input_id]._add_control_input(op)  # pylint: disable=protected-access
+          merge_for_resource[input_id]._add_control_input(op)
         if is_read:
           reads_since_last_write_to_resource[input_id].append(op)
         else:
@@ -383,25 +423,48 @@ class AutomaticControlDependencies(object):
           last_write_to_resource[input_id] = op
 
       if (op_is_stateful(op) and not resource_inputs
-          and op._control_flow_context is None):  # pylint: disable=protected-access
+          and op._control_flow_context is None):
         if None in last_write_to_resource:
-          op._add_control_input(last_write_to_resource[None])  # pylint: disable=protected-access
+          op._add_control_input(last_write_to_resource[None])
         last_write_to_resource[None] = op
-      control_inputs = [
-          c for c in control_inputs if is_building_function or
-          (c._control_flow_context is op._control_flow_context)]  # pylint: disable=protected-access
-      op._add_control_inputs(control_inputs)  # pylint: disable=protected-access
+
+      # Ensure ordering of collective ops
+      manager_ids = collective_manager_ids_from_op(op)
+      for manager_id in manager_ids:
+        if manager_id in collective_manager_scopes_opened:
+          # Chain this function call if the scope was opened.
+          op._add_control_input(collective_manager_scopes_opened[manager_id])
+          collective_manager_scopes_opened[manager_id] = op
+        else:
+          # If this op is in a scope not created here, create a chain starting
+          # at this op.
+          if manager_id in collective_manager_scopes_used:
+            op._add_control_input(collective_manager_scopes_used[manager_id])
+          collective_manager_scopes_used[manager_id] = op
+
+      if control_inputs and not is_building_function:
+        control_inputs = [
+            c for c in control_inputs
+            if c._control_flow_context is op._control_flow_context
+        ]
+
+      op._add_control_inputs(control_inputs)
 
     # Ensure all ops which must run do run
     self.ops_which_must_run.update(ops_which_must_run)
     for r in nest.flatten(list(self._returned_tensors), expand_composites=True):
       if self.ops_which_must_run:
-        r.op._add_control_inputs(  # pylint: disable=protected-access
-            [
-                o for o in self.ops_which_must_run
-                if r.graph.building_function or
-                (o._control_flow_context is r.op._control_flow_context)  # pylint: disable=protected-access
-            ])
+        updated_ops_which_must_run = []
+        if r.graph.building_function:
+          updated_ops_which_must_run = self.ops_which_must_run
+        else:
+          updated_ops_which_must_run = [
+              o for o in self.ops_which_must_run
+              if o._control_flow_context is r.op._control_flow_context
+          ]
+        r.op._add_control_inputs(updated_ops_which_must_run)
+
+    self.collective_manager_ids_used = collective_manager_scopes_used
 
 
 _acd_resource_resolvers_registry = registry.Registry("acd_resource_resolvers")
diff --git a/tensorflow/python/framework/auto_control_deps_utils.py b/tensorflow/python/framework/auto_control_deps_utils.py
index f1b23556c7a..63ca73bb034 100644
--- a/tensorflow/python/framework/auto_control_deps_utils.py
+++ b/tensorflow/python/framework/auto_control_deps_utils.py
@@ -25,6 +25,9 @@ READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs"
 RESOURCE_READ_OPS = set()
 
 
+COLLECTIVE_MANAGER_IDS = "_collective_manager_ids"
+
+
 def register_read_only_resource_op(op_type):
   """Declares that `op_type` does not update its touched resource."""
   RESOURCE_READ_OPS.add(op_type)
diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py
index d686df562a6..d702771cef3 100644
--- a/tensorflow/python/framework/func_graph.py
+++ b/tensorflow/python/framework/func_graph.py
@@ -1023,6 +1023,8 @@ def func_graph_from_py_func(name,
 
   if add_control_dependencies:
     func_graph.control_outputs.extend(deps_control_manager.ops_which_must_run)
+    func_graph.collective_manager_ids_used = (
+        deps_control_manager.collective_manager_ids_used)
 
   return func_graph
 
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index a90f223ac92..63c653b5df1 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -1181,6 +1181,9 @@ def partitioned_call(args,
   outputs = op.outputs
   if hasattr(f, "graph"):
     _set_read_only_resource_inputs_attr(op, f.graph)
+    if hasattr(f.graph, "collective_manager_ids_used"):
+      ops.set_int_list_attr(
+          op, acd.COLLECTIVE_MANAGER_IDS, f.graph.collective_manager_ids_used)
   return outputs if outputs else op