diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 2ff8c897c80..efa10d6ee45 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -63,6 +63,7 @@ py_library(
     srcs = ["cross_device_ops.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":collective_util",
         ":cross_device_utils",
         ":device_util",
         ":reduce_util",
@@ -97,6 +98,7 @@ py_library(
         "//tensorflow/python:gradients",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:nccl_ops",
+        "//tensorflow/python:platform",
     ],
 )
 
@@ -145,6 +147,7 @@ py_library(
     ],
     srcs_version = "PY2AND3",
     deps = [
+        ":collective_util",
         ":device_util",
         ":numpy_dataset",
         ":reduce_util",
@@ -580,6 +583,15 @@ py_library(
     ],
 )
 
+py_library(
+    name = "collective_util",
+    srcs = ["collective_util.py"],
+    deps = [
+        "//tensorflow/python:util",
+        "//tensorflow/python:variable_scope",
+    ],
+)
+
 py_library(
     name = "shared_variable_creator",
     srcs = ["shared_variable_creator.py"],
@@ -795,7 +807,9 @@ cuda_py_test(
     name = "cross_device_utils_test",
     srcs = ["cross_device_utils_test.py"],
     deps = [
+        "//tensorflow/python:array_ops",
         "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python/distribute:combinations",
@@ -815,6 +829,7 @@ cuda_py_test(
     ],
     deps = [
         ":collective_all_reduce_strategy",
+        ":collective_util",
         ":mirrored_strategy",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:constant_op",
diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py
index e77fce0f2a5..ab9721e1bfb 100644
--- a/tensorflow/python/distribute/collective_all_reduce_strategy.py
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py
@@ -95,6 +95,7 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
         TFConfigClusterResolver which is instantiated from the TF_CONFIG env
         var.
     """
+    # TODO(b/150151677): consider moving communication to CollectiveHints.
     super(CollectiveAllReduceStrategy, self).__init__(
         CollectiveAllReduceExtended(
             self,
@@ -505,7 +506,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
 
     return updated_config
 
-  def _reduce_to(self, reduce_op, value, destinations):
+  def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
     if (isinstance(value, values.Mirrored) and
         reduce_op == reduce_util.ReduceOp.MEAN):
       return value
@@ -526,7 +527,10 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
       return cross_device_ops_lib.reduce_non_distributed_value(
           reduce_op, value, destinations, len(self.worker_devices))
     return self._get_cross_device_ops().reduce(
-        reduce_op, value, destinations=destinations)
+        reduce_op,
+        value,
+        destinations=destinations,
+        experimental_hints=experimental_hints)
 
   def _warn_nccl_no_gpu(self):
     if ((self._communication ==
diff --git a/tensorflow/python/distribute/collective_util.py b/tensorflow/python/distribute/collective_util.py
new file mode 100644
index 00000000000..fb7008d1636
--- /dev/null
+++ b/tensorflow/python/distribute/collective_util.py
@@ -0,0 +1,63 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for collectives."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("distribute.experimental.CollectiveHints")
+class Hints(object):
+  """Hints for collective operations like AllReduce.
+
+  This can be passed to methods like
+  `tf.distribute.get_replica_context().all_reduce()` to optimize collective
+  operation performance. Note that these are only hints, which may or may not
+  change the actual behavior. Some options only apply to certain strategies and
+  are ignored by others.
+
+  One common optimization is to break gradient all-reduces into multiple packs
+  so that weight updates can overlap with gradient all-reduce.
+
+  Example:
+
+  ```python
+  hints = tf.distribute.experimental.CollectiveHints(
+      bytes_per_pack=50 * 1024 * 1024)
+  grads = tf.distribute.get_replica_context().all_reduce(
+      'sum', grads, experimental_hints=hints)
+  optimizer.apply_gradients(zip(grads, vars), all_reduce_sum_gradients=False)
+  ```
+
+  """
+
+  def __init__(self, bytes_per_pack=0):
+    """Creates a CollectiveHints.
+
+    Args:
+      bytes_per_pack: A non-negative integer. Breaks collective operations into
+        packs of a certain size. If it's zero, the value is determined
+        automatically. This currently only applies to all-reduce with
+        `MultiWorkerMirroredStrategy`.
+
+    Raises:
+      ValueError: When arguments have invalid value.
+    """
+    if bytes_per_pack < 0:
+      raise ValueError("bytes_per_pack must be non-negative")
+    self.bytes_per_pack = bytes_per_pack
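For reviewers, here is a minimal usage sketch (not part of the patch) of the three public entry points that gain an `experimental_hints` argument in this change. It assumes a working `MultiWorkerMirroredStrategy` setup (cluster, `TF_CONFIG`, etc.); the values and destinations are placeholders.

```python
import tensorflow as tf

hints = tf.distribute.experimental.CollectiveHints(
    bytes_per_pack=32 * 1024 * 1024)  # pack all-reduces into ~32MB chunks
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

# 1) Inside a replica function: all-reduce a nest of tensors.
def replica_fn(grads):
  return tf.distribute.get_replica_context().all_reduce(
      'sum', grads, experimental_hints=hints)

# 2) In cross-replica context: reduce a single value to a destination.
def reduce_one(per_replica_value):
  return strategy.extended.reduce_to(
      tf.distribute.ReduceOp.SUM, per_replica_value, '/device:CPU:0',
      experimental_hints=hints)

# 3) In cross-replica context: batch several reductions so they can be packed
#    according to bytes_per_pack before the collectives launch.
def reduce_many(value_destination_pairs):
  return strategy.extended.batch_reduce_to(
      tf.distribute.ReduceOp.SUM, value_destination_pairs,
      experimental_hints=hints)
```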
diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py
index 2f2a105df84..eeb45428bab 100644
--- a/tensorflow/python/distribute/cross_device_ops.py
+++ b/tensorflow/python/distribute/cross_device_ops.py
@@ -19,11 +19,12 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
-import enum
 
+import enum
 import six
 
 from tensorflow.python.client import device_lib
+from tensorflow.python.distribute import collective_util
 from tensorflow.python.distribute import cross_device_utils
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import reduce_util
@@ -222,7 +223,11 @@ class CrossDeviceOps(object):
     # Returns 1 by default, the value may be overridden by sub classes.
     return 1
 
-  def reduce(self, reduce_op, per_replica_value, destinations):
+  def reduce(self,
+             reduce_op,
+             per_replica_value,
+             destinations,
+             experimental_hints=None):
     """Reduce `per_replica_value` to `destinations`.
 
     It runs the reduction operation defined by `reduce_op` and put the
@@ -231,8 +236,10 @@ class CrossDeviceOps(object):
     Args:
       reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
         per_replica_value will be reduced.
-      per_replica_value: a PerReplica object or a tensor with device set.
+      per_replica_value: A PerReplica object or a tensor with device set.
       destinations: the reduction destinations.
+      experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
+        to perform collective operations.
 
     Returns:
       a Mirrored object.
@@ -254,10 +261,15 @@ class CrossDeviceOps(object):
           per_replica_value.values,
           wrap_class=value_lib.Mirrored)
 
+    if experimental_hints is None:
+      experimental_hints = collective_util.Hints()
     return self.reduce_implementation(reduce_op, per_replica_value,
-                                      destinations)
+                                      destinations, experimental_hints)
 
-  def batch_reduce(self, reduce_op, value_destination_pairs):
+  def batch_reduce(self,
+                   reduce_op,
+                   value_destination_pairs,
+                   experimental_hints=None):
     """Reduce PerReplica objects in a batch.
 
     Reduce each first element in `value_destination_pairs` to each second
@@ -267,10 +279,12 @@ class CrossDeviceOps(object):
     fuse several tensors into one or multiple packs before reduction.
 
     Args:
-      reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
-        the `per_replica_value` will be reduced.
-      value_destination_pairs: a list or a tuple of PerReplica objects
-        (or tensors with device set if there is one device) and destinations.
+      reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how the
+        `per_replica_value` will be reduced.
+      value_destination_pairs: A list or a tuple of PerReplica objects (or
+        tensors with device set if there is one device) and destinations.
+      experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
+        to perform collective operations.
 
     Returns:
       a list of Mirrored objects.
@@ -299,7 +313,10 @@ class CrossDeviceOps(object):
           for v, _ in value_destination_pairs
       ]
 
-    return self.batch_reduce_implementation(reduce_op, value_destination_pairs)
+    if experimental_hints is None:
+      experimental_hints = collective_util.Hints()
+    return self.batch_reduce_implementation(reduce_op, value_destination_pairs,
+                                            experimental_hints)
 
   def broadcast(self, tensor, destinations):
     """Broadcast the `tensor` to destinations.
@@ -315,7 +332,8 @@ class CrossDeviceOps(object):
     return self.broadcast_implementation(tensor, destinations)
 
   @doc_controls.for_subclass_implementers
-  def reduce_implementation(self, reduce_op, per_replica_value, destinations):
+  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
+                            experimental_hints):
     """The implementation of reduce of `per_replica_value` to `destinations`.
 
     Overriding this method is useful for subclass implementers.
@@ -326,8 +344,10 @@ class CrossDeviceOps(object):
     Args:
       reduce_op: An instance `tf.distribute.ReduceOp` that indicates of how
         per_replica_value will be reduced.
-      per_replica_value: a PerReplica object or a tensor with device set.
+      per_replica_value: A PerReplica object or a tensor with device set.
       destinations: the reduction destinations.
+      experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
+        to perform collective operations.
 
     Returns:
       a Mirrored object.
@@ -340,7 +360,8 @@ class CrossDeviceOps(object):
         "_reduce method must be implemented in descendants.")
 
   @doc_controls.for_subclass_implementers
-  def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
+  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
+                                  experimental_hints):
     """Implementation of reduce PerReplica objects in a batch.
 
     Overriding this method is useful for subclass implementers.
@@ -351,8 +372,10 @@ class CrossDeviceOps(object):
     Args:
       reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
         per_replica_value will be reduced.
-      value_destination_pairs: an iterable of tuples of PerReplica objects
+      value_destination_pairs: An iterable of tuples of PerReplica objects
         (or tensors with device set if there is one device) and destinations.
+      experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
+        to perform collective operations.
 
     Returns:
       a list of Mirrored objects.
@@ -362,7 +385,8 @@ class CrossDeviceOps(object):
         tuples of PerReplica objects and destinations
     """
     raise NotImplementedError(
-        "_batch_reduce method must be implemented in descendants.")
+        "batch_reduce_implementation method must be implemented in descendants."
+    )
 
   @doc_controls.for_subclass_implementers
   def broadcast_implementation(self, tensor, destinations):
@@ -403,7 +427,9 @@ class ReductionToOneDevice(CrossDeviceOps):
     self.accumulation_fn = accumulation_fn or math_ops.add_n
     super(ReductionToOneDevice, self).__init__()
 
-  def reduce_implementation(self, reduce_op, per_replica_value, destinations):
+  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
+                            experimental_hints):
+    del experimental_hints  # Unused.
     if check_destinations(destinations):
       devices = get_devices_from(destinations)
     else:
@@ -416,9 +442,11 @@ class ReductionToOneDevice(CrossDeviceOps):
                              self.accumulation_fn, reduce_op)
     return self.broadcast(reduced, destinations)
 
-  def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
+  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
+                                  experimental_hints):
     return [
-        self.reduce_implementation(reduce_op, t, destinations=v)
+        self.reduce_implementation(
+            reduce_op, t, destinations=v, experimental_hints=experimental_hints)
         for t, v in value_destination_pairs
     ]
 
@@ -626,21 +654,24 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
     self._simple_cross_replica_ops = ReductionToOneDevice()
     super(AllReduceCrossDeviceOps, self).__init__()
 
-  def reduce_implementation(self, reduce_op, per_replica_value, destinations):
+  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
+                            experimental_hints):
+    del experimental_hints  # Unused.
     if _devices_match(per_replica_value, destinations):
       return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
     else:
       return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value,
                                                    destinations)
 
-  def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
+  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
+                                  experimental_hints):
     if _all_devices_match(value_destination_pairs):
       return self._batch_all_reduce(reduce_op,
                                     [v[0] for v in value_destination_pairs])
     else:
       return [
-          self.reduce_implementation(reduce_op, t, destinations=v)
-          for t, v in value_destination_pairs
+          self.reduce_implementation(reduce_op, value, dest, experimental_hints)
+          for value, dest in value_destination_pairs
       ]
 
   def _batch_all_reduce(self, reduce_op, per_replica_values):
@@ -904,7 +935,6 @@ class CollectiveAllReduce(CrossDeviceOps):
   def __init__(self,
                num_workers=1,
                num_gpus_per_worker=0,
-               num_packs=1,
                collective_keys=None,
                communication=CollectiveCommunication.AUTO):
     """Initializes the object.
@@ -912,13 +942,11 @@ class CollectiveAllReduce(CrossDeviceOps):
     Args:
       num_workers: number of workers in the between-graph replicated training.
       num_gpus_per_worker: number of GPUs per worker.
-      num_packs: gradients will be packed into `num_packs` chunks.
       collective_keys: an optional CollectiveKey object.
       communication: indicates which collective communication to use.
     """
     self._num_workers = num_workers
     self._num_gpus_per_worker = num_gpus_per_worker
-    self._num_packs = num_packs
     self._collective_keys = (collective_keys or
                              cross_device_utils.CollectiveKeys())
     self._communication = communication
@@ -928,8 +956,10 @@ class CollectiveAllReduce(CrossDeviceOps):
   def _num_between_graph_workers(self):
     return self._num_workers
 
-  def reduce_implementation(self, reduce_op, per_replica_value, destinations):
-    all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value])[0]
+  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
+                            experimental_hints):
+    all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value],
+                                         experimental_hints)[0]
     devices = get_devices_from(destinations)
 
     if (isinstance(all_reduced, value_lib.Mirrored) and
@@ -958,11 +988,13 @@ class CollectiveAllReduce(CrossDeviceOps):
             index.append(array_ops.identity(all_reduced._primary))  # pylint: disable=protected-access
     return value_lib.regroup(index, wrap_class=value_lib.Mirrored)
 
-  def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
+  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
+                                  experimental_hints):
     all_devices_match = _all_devices_match(value_destination_pairs)
     if all_devices_match:
       return self._batch_all_reduce(reduce_op,
-                                    [v[0] for v in value_destination_pairs])
+                                    [v[0] for v in value_destination_pairs],
+                                    experimental_hints)
     else:
       if not all_devices_match:
         logging.log_first_n(
@@ -970,47 +1002,18 @@ class CollectiveAllReduce(CrossDeviceOps):
             "destinations are different.", 10)
 
       return [
-          self.reduce_implementation(reduce_op, t, destinations=v)
-          for t, v in value_destination_pairs
+          self.reduce_implementation(reduce_op, value, dest, experimental_hints)
+          for value, dest in value_destination_pairs
       ]
 
-  def _make_gradient_chunks(self, per_replica_values, num_packs):
-    """Make `per_replica_values` into chunks."""
-    chunked_by_device = _group_value_by_device(per_replica_values)
-    chunked_by_var = list(zip(*chunked_by_device))
-    # chunked_by_var is chunked by variables and takes the following format:
-    # [((grad0_gpu0, v0_gpu0), (grad0_gpu1, v0_gpu1), (grad0_gpu2, v0_gpu2) ..),
-    #  ((grad1_gpu0, v1_gpu0), (grad1_gpu1, v1_gpu1), (grad1_gpu0, v1_gpu2) ..),
-    #  ((grad2_gpu0, v2_gpu0), (grad2_gpu1, v2_gpu1), (grad2_gpu0, v2_gpu2) ..),
-    #  ...
-    # ]
-
-    # No chunking if number of variables is fewer than number of packs.
-    if len(chunked_by_var) < num_packs:
-      return [chunked_by_var]
-
-    # First n-1 chunks get `chunk_size` grads, last chunk gets leftover grads.
-    # This strategy can cause the last chunk to have larger size compared to the
-    # first n-1 chunks.  Alternatively, we can increment chunk_size by 1 to get
-    # slightly larger first n-1 chunks and smaller last chunk.
-    # TODO(ayushd): compare different packing strategies.
-    chunk_size = len(chunked_by_var) // num_packs
-    leftover_size = len(chunked_by_var) - chunk_size * (num_packs - 1)
-    assert leftover_size > 0
-    chunked_gv = [
-        chunked_by_var[x:x + chunk_size]
-        for x in range(0, len(chunked_by_var) - leftover_size, chunk_size)
-    ]
-    chunked_gv.append(chunked_by_var[-leftover_size:])
-
-    return chunked_gv
-
-  def _batch_all_reduce(self, reduce_op, per_replica_values):
+  def _batch_all_reduce(self, reduce_op, per_replica_values,
+                        experimental_hints):
     """All reduce algorithm in a batch."""
     dense_values, dense_indices, sparse_values, sparse_indices = (
         cross_device_utils.split_by_sparsity(per_replica_values))
     if dense_values:
-      dense_results = self._do_batch_all_reduce_dense(reduce_op, dense_values)
+      dense_results = self._do_batch_all_reduce_dense(reduce_op, dense_values,
+                                                      experimental_hints)
     else:
       dense_results = []
     if sparse_values:
@@ -1018,83 +1021,84 @@ class CollectiveAllReduce(CrossDeviceOps):
                                                         sparse_values)
     else:
       sparse_results = []
-    return cross_device_utils.stitch_values(((dense_results, dense_indices),
-                                             (sparse_results, sparse_indices)))
+    return cross_device_utils.stitch_values(
+        ((dense_results, dense_indices), (sparse_results, sparse_indices)))
 
-  def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values):
+  def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values,
+                                 experimental_hints):
     """All-reduce across all workers in a batch."""
 
-    chunked_gv = self._make_gradient_chunks(per_replica_values, self._num_packs)
-    # Actual number of packs may be different from `self._num_packs`.  e.g. if
-    # there are fewer tensors than `self._num_packs`.
-    num_actual_packs = len(chunked_gv)
-
     batch_size = len(per_replica_values)
     # Pass self._communication to the runtime as a communication hint.
-    communication_hint = self._communication.value
+    communication = self._communication.value
     # For now, we use NCCL only when batch_size > 1.
     # TODO(b/132575814): switch to NCCL for all collectives when communication
     # is NCCL.
     if self._communication == CollectiveCommunication.NCCL and batch_size == 1:
-      communication_hint = CollectiveCommunication.AUTO.value
+      communication = CollectiveCommunication.AUTO.value
+
+    # Reverse the list so that there's a better chance that values follow
+    # the order in which they are calculated (e.g. when they're gradients), so
+    # as to overlap calculation with communication. However, this may not be
+    # optimal for cases like gradients of complicated non-sequential models.
+    #
+    # Note that we reverse the list before packing so that the first pack won't
+    # be too small, since it's more likely for the first few packs to have long
+    # queuing time due to concurrent intense computation.
+    #
+    # TODO(b/147393503): explore solutions for optimal ordering.
+    packs = cross_device_utils.pack_by_size(
+        list(reversed(per_replica_values)), experimental_hints.bytes_per_pack)
 
     if batch_size > 1:
       logging.info(
           "Collective batch_all_reduce: %d all-reduces, num_workers = %d, "
-          "communication_hint = %s, num_packs = %d" %
-          (batch_size, self._num_workers, communication_hint, num_actual_packs))
+          "communication_hint = %s, num_packs = %d", batch_size,
+          self._num_workers, communication, len(packs))
     else:
       logging.log_first_n(
           logging.INFO, "Collective batch_all_reduce: %d all-reduces, "
           "num_workers = %d, communication_hint = %s, num_packs = %d" %
-          (batch_size, self._num_workers, communication_hint, num_actual_packs),
-          10)
+          (batch_size, self._num_workers, communication, len(packs)), 10)
 
     def batch_fn():
       """Wrapper function around batched all-reduce calls."""
-      reduced_gv_list = []
-      # Reverse the gradient lists so that the gradient grouping roughly follows
-      # the order in which gradients are calculated in backprop.  This should
-      # enable overlapping gradient all-reduce with backprop for most models.
-      # However, it is likely that for some complicated non-sequential models
-      # this grouping is not optimal.
-      #
-      # TODO(b/147393503): explore solutions for optimal gradient grouping.
-      for chunk in reversed(chunked_gv):
-        # By placing all CollectiveReduce ops in a chunk under single name
-        # scope, we ensure they will be picked up by the `ScopedAllocator`
-        # grappler optimizer and packed into a single all-reduce.
+      reduced_values = []
+      for pack in packs:
+        # By placing all CollectiveReduce ops in a pack under single name scope,
+        # we ensure they will be picked up by the `ScopedAllocator` grappler
+        # optimizer and packed into a single all-reduce.
         with ops.name_scope("allreduce"):
-          for grad_and_vars in reversed(chunk):
-            # Gradients for the same variable but from different devices.
-            grads = [g for g, _ in grad_and_vars]
+          for per_replica in pack:
             # Add control dependencies per device from the last gradients to the
             # current set, in order to serialize NCCL launches.
-            if (communication_hint == CollectiveCommunication.NCCL.value and
-                reduced_gv_list):
-              control_input_grads = [g for g, _ in reduced_gv_list[-1]]
+            if (communication == CollectiveCommunication.NCCL.value and
+                reduced_values):
+              control_inputs = list(reduced_values[-1])
             else:
-              control_input_grads = None
-            collective_reduced = cross_device_utils.build_collective_reduce(
-                grads, self._num_workers, self._collective_keys, "Add", "Id",
-                communication_hint, control_input_grads)
-            result = []
-            for (_, v), g in zip(grad_and_vars, collective_reduced):
-              result.append([g, v])
-            reduced_gv_list.append(result)
-      # Reverse the batch reduced gradients to (approximately) recover the order
-      # in the input per_replica_values.
-      reduced_gv_list.reverse()
-      return reduced_gv_list
+              control_inputs = None
+            reduced_values.append(
+                cross_device_utils.build_collective_reduce(
+                    per_replica.values, self._num_workers,
+                    self._collective_keys, "Add", "Id", communication,
+                    control_inputs))
+      return reduced_values
+
     if context.executing_eagerly():
       batch_fn = def_function.function(batch_fn)
 
-    new_device_grads = [list(x) for x in zip(*batch_fn())]
-    return _ungroup_and_make_mirrored(
-        new_device_grads,
-        per_replica_values[0],
-        reduce_op,
-        num_between_graph_workers=self._num_workers)
+    reduced_values = batch_fn()
+    mirrored = []
+    # Reverse the order of reduced values to recover the order of the input.
+    for value in reversed(reduced_values):
+      if reduce_op == reduce_util.ReduceOp.MEAN:
+        # Assume each worker has the same number of replicas.
+        num_replicas = len(value) * self._num_workers
+        for i, v in enumerate(value):
+          with ops.device(v.device):
+            value[i] = v / num_replicas
+      mirrored.append(value_lib.regroup(value, wrap_class=value_lib.Mirrored))
+    return mirrored
 
   def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values):
     """All-reduce IndexedSlices across all workers in a batch."""
@@ -1106,48 +1110,31 @@ class CollectiveAllReduce(CrossDeviceOps):
 
     # Pass self._communication to the runtime as a communication hint.
     communication_hint = self._communication.value
-    # For now, we use NCCL only when batch_size > 1 and num_packs is 1.
-    # TODO(b/132575814): Enable NCCL if num_packs > 1.
-    # TODO(b/132575814): Switch to NCCL for all collectives when communication
+    # For now, we use NCCL only when batch_size > 1.
+    # TODO(b/132575814): switch to NCCL for all collectives when communication
     # is NCCL.
-    if self._communication == CollectiveCommunication.NCCL and (
-        len(per_replica_values) == 1 or self._num_packs != 1):
+    if self._communication == CollectiveCommunication.NCCL and len(
+        per_replica_values) == 1:
       communication_hint = CollectiveCommunication.AUTO.value
 
-    chunked_gv = self._make_gradient_chunks(per_replica_values, self._num_packs)
+    gathered_values = []
+    with ops.name_scope("allreduce"):
+      for per_replica in per_replica_values:
+        gathered_values.append(
+            cross_device_utils.build_collective_gather_indexed_slices(
+                per_replica.values, self._num_workers, self._collective_keys,
+                communication_hint))
 
-    reduced_gv_list = []
-    for chunk in chunked_gv:
-      # By placing all CollectiveReduce ops in a chunk under single name scope,
-      # we ensure they will be picked up by the `ScopedAllocator` grappler
-      # optimizer and packed into a single all-reduce.
-      with ops.name_scope("allreduce"):
-        for grad_and_vars in chunk:
-          grads = [g for g, _ in grad_and_vars]
-
-          # Add control dependencies per device from the last gradients to the
-          # current set, in order to serialize NCCL launches.
-          if (communication_hint == CollectiveCommunication.NCCL.value and
-              reduced_gv_list):
-            control_input_grads = [g for g, _ in reduced_gv_list[-1]]
-          else:
-            control_input_grads = None
-
-          collective_reduced = (
-              cross_device_utils.build_collective_gather_indexed_slices(
-                  grads, self._num_workers, self._collective_keys,
-                  communication_hint, control_input_grads))
-          result = []
-          for (_, v), g in zip(grad_and_vars, collective_reduced):
-            result.append([g, v])
-          reduced_gv_list.append(result)
-
-    new_device_grads = [list(x) for x in zip(*reduced_gv_list)]
-    return _ungroup_and_make_mirrored(
-        new_device_grads,
-        per_replica_values[0],
-        reduce_op,
-        num_between_graph_workers=self._num_workers)
+    mirrored = []
+    for value in gathered_values:
+      if reduce_op == reduce_util.ReduceOp.MEAN:
+        # Assume each worker has the same number of replicas.
+        num_replicas = len(value) * self._num_workers
+        for i, v in enumerate(value):
+          with ops.device(v.device):
+            value[i].values = value[i].values / num_replicas
+      mirrored.append(value_lib.regroup(value, wrap_class=value_lib.Mirrored))
+    return mirrored
 
 
 def choose_the_best(devices, session_config=None):
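The dense path in `CollectiveAllReduce._do_batch_all_reduce_dense` now reverses the inputs, packs them by size, all-reduces each pack, and reverses the results to restore the input order. The sketch below illustrates just that ordering in plain Python; `pack_fn` and `all_reduce_fn` are stand-ins for `cross_device_utils.pack_by_size` and `build_collective_reduce`, not real APIs.

```python
def batch_all_reduce_order(values, pack_fn, all_reduce_fn):
  """Mirrors the ordering logic of _do_batch_all_reduce_dense."""
  # Later values (e.g. gradients of earlier layers) are launched first so
  # communication can overlap with backprop.
  packs = pack_fn(list(reversed(values)))
  reduced = []
  for pack in packs:
    for value in pack:
      reduced.append(all_reduce_fn(value))  # results are in reversed order
  # Reverse once more so the results line up with the original inputs.
  return list(reversed(reduced))

# With values [g0, g1, g2, g3] and packs of two, the collectives launch in the
# order g3, g2, g1, g0 and the returned list is aligned with [g0, g1, g2, g3].
```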
diff --git a/tensorflow/python/distribute/cross_device_ops_test.py b/tensorflow/python/distribute/cross_device_ops_test.py
index 3a9d7b7ec44..aba5316f7b5 100644
--- a/tensorflow/python/distribute/cross_device_ops_test.py
+++ b/tensorflow/python/distribute/cross_device_ops_test.py
@@ -24,6 +24,7 @@ from absl.testing import parameterized
 import numpy as np
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.distribute import collective_all_reduce_strategy
+from tensorflow.python.distribute import collective_util
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
 from tensorflow.python.distribute import cross_device_utils
@@ -463,8 +464,7 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
                         num_gpus=0,
                         communication=CollectiveCommunication.AUTO,
                         use_strategy_object=False,
-                        local_mode=False,
-                        num_packs=1):
+                        local_mode=False):
     collective_keys = cross_device_utils.CollectiveKeys(
         group_key_start=10 + CollectiveAllReduceTest.collective_key_base,
         op_instance_key_start=100 + CollectiveAllReduceTest.collective_key_base,
@@ -487,7 +487,6 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
             1,
             num_gpus,
             collective_keys=collective_keys,
-            num_packs=num_packs,
             communication=communication)
         return collective_all_reduce_ops, devices, ""
     else:
@@ -520,7 +519,6 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
             NUM_WORKERS,
             num_gpus,
             collective_keys=collective_keys,
-            num_packs=num_packs,
             communication=communication)
         return (collective_all_reduce_ops, devices,
                 "grpc://" + self._cluster_spec[task_type][task_id])
@@ -532,15 +530,14 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
                       communication,
                       use_strategy_object=False,
                       local_mode=False,
-                      num_packs=1):
+                      hints=None):
     collective_all_reduce, devices, master_target = self._get_test_objects(
         task_type,
         task_id,
         num_gpus,
         communication=communication,
         use_strategy_object=use_strategy_object,
-        local_mode=local_mode,
-        num_packs=num_packs)
+        local_mode=local_mode)
     if local_mode:
       num_workers = 1
       worker_device = None
@@ -553,17 +550,19 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
       if use_strategy_object:
         with test_object.scope():
           return test_object.extended.reduce_to(reduce_op, per_replica,
-                                                destinations)
+                                                destinations, hints)
       else:
-        return test_object.reduce(reduce_op, per_replica, destinations)
+        return test_object.reduce(reduce_op, per_replica, destinations, hints)
 
     def _batch_reduce(test_object, reduce_op, value_destination_pairs):
       if use_strategy_object:
         with test_object.scope():
           return test_object.extended.batch_reduce_to(reduce_op,
-                                                      value_destination_pairs)
+                                                      value_destination_pairs,
+                                                      hints)
       else:
-        return test_object.batch_reduce(reduce_op, value_destination_pairs)
+        return test_object.batch_reduce(reduce_op, value_destination_pairs,
+                                        hints)
 
     with ops.Graph().as_default(), \
          ops.device(worker_device), \
@@ -724,16 +723,17 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
           mode=["graph"],
           required_gpus=[0, 1, 2],
           use_strategy_object=[True, False],
-          num_packs=[1, 2]))
+          bytes_per_pack=[0, 1, 4]))
   def testReductionDistributed(self, required_gpus, use_strategy_object,
-                               num_packs):
+                               bytes_per_pack):
+    hints = collective_util.Hints(bytes_per_pack=bytes_per_pack)
     self._run_between_graph_clients(
         self._test_reduction,
         self._cluster_spec,
         required_gpus,
         communication=CollectiveCommunication.RING,
         use_strategy_object=use_strategy_object,
-        num_packs=num_packs)
+        hints=hints)
 
   @combinations.generate(
       combinations.combine(
diff --git a/tensorflow/python/distribute/cross_device_utils.py b/tensorflow/python/distribute/cross_device_utils.py
index fa6f612af17..f9917385b59 100644
--- a/tensorflow/python/distribute/cross_device_utils.py
+++ b/tensorflow/python/distribute/cross_device_utils.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import collective_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nccl_ops
+from tensorflow.python.platform import tf_logging as logging
 
 
 OP_INSTANCE_KEY_START_NUMBER = 100
@@ -896,6 +897,67 @@ def stitch_values(values_and_indices_list):
   return result
 
 
+def per_replica_num_elements(per_replica):
+  """Returns the static number of elements of one replica.
+
+  Args:
+    per_replica: A PerReplica of Tensor or IndexedSlices.
+
+  Returns:
+    Number of elements. None if some replica has a different or unknown shape.
+  """
+
+  values = per_replica._values  # pylint: disable=protected-access
+  s0 = values[0].shape
+  for v in values:
+    assert not isinstance(v, ops.IndexedSlices)
+    if v.shape != s0:
+      return None
+  return s0.num_elements()
+
+
+def pack_by_size(per_replica_list, bytes_per_pack):
+  """Packs `per_replica_list` into chunks of `bytes_per_pack`.
+
+  The method preserves the original order of `per_replica_list`. The packing is
+  best effort; each pack could have more or fewer bytes than `bytes_per_pack`.
+  It only packs values with known shape. Note that the usage is different from
+  `cross_device_ops._pack_tensors`: this function is intended to work with the
+  ScopedAllocator style batching used in `CollectiveAllReduce`.
+
+  Args:
+    per_replica_list: A list of PerReplica.
+    bytes_per_pack: Bytes per pack.
+
+  Returns:
+    A list of packs of PerReplica. All values are packed into one pack if
+      `bytes_per_pack` is zero or any of the values has an unknown shape.
+  """
+
+  if bytes_per_pack == 0:
+    return [per_replica_list]
+  packs = []
+  last_pack_size = 0
+  for value in per_replica_list:
+    num_elements = per_replica_num_elements(value)
+    if num_elements is None:
+      # Can't pack values with unknown shape.
+      logging.warning(
+          'not packing values due to the unknown or inconsistent shape of %s',
+          value)
+      return [per_replica_list]
+    size = num_elements * value._primary.dtype.size  # pylint: disable=protected-access
+    # Try to keep each pack as close to bytes_per_pack as possible, while
+    # ensuring each pack is at least bytes_per_pack large, i.e. we err on the
+    # side of having few but large packs.
+    if not packs or last_pack_size > bytes_per_pack:
+      packs.append([])
+      last_pack_size = 0
+    packs[-1].append(value)
+    last_pack_size += size
+  return packs
+
+
 def _control_input(inputs, control_inputs, idx):
   """Returns the `idx`-th item in control_inputs to be used in ops.control_dependencies.
 
diff --git a/tensorflow/python/distribute/cross_device_utils_test.py b/tensorflow/python/distribute/cross_device_utils_test.py
index 217883ea21b..ae0f3b8fe76 100644
--- a/tensorflow/python/distribute/cross_device_utils_test.py
+++ b/tensorflow/python/distribute/cross_device_utils_test.py
@@ -26,8 +26,11 @@ from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import values as value_lib
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import input_layer
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 
 
@@ -133,8 +136,86 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
 
     self.assertIsInstance(result, ops.IndexedSlices)
     self._assert_values_equal(t, result)
-    self.assertEqual(device_util.resolve(destination),
-                     device_util.resolve(result.device))
+    self.assertEqual(
+        device_util.resolve(destination), device_util.resolve(result.device))
+
+
+class PackBySizeTest(test.TestCase):
+
+  def assertShape(self, per_replica, shape):
+    for v in per_replica._values:  # pylint: disable=protected-access
+      self.assertEqual(v.shape, shape)
+
+  def testPreferLargerPack(self):
+    # Each pack, except the last one, should be equal to or larger than
+    # bytes_per_pack.
+    values = [
+        # size = 2 * 4 * 4 * 4 = 128
+        array_ops.ones([2, 4, 4], dtype=dtypes.float32),
+        # size = 8 * 4 = 32
+        array_ops.ones([8], dtype=dtypes.int32),
+        # size = 10 * 10 * 8 = 800
+        array_ops.ones([10, 10], dtype=dtypes.int64),
+        # size = 1 * 4 = 4
+        array_ops.ones([1], dtype=dtypes.int32),
+    ]
+    per_replica_values = [value_lib.PerReplica([v, v]) for v in values]
+    packs = cross_device_utils.pack_by_size(
+        per_replica_values, bytes_per_pack=200)
+    self.assertLen(packs, 2)
+    self.assertLen(packs[0], 3)
+    self.assertShape(packs[0][0], [2, 4, 4])
+    self.assertShape(packs[0][1], [8])
+    self.assertShape(packs[0][2], [10, 10])
+    self.assertLen(packs[1], 1)
+    self.assertShape(packs[1][0], [1])
+
+  def testZeroBytesPerPack(self):
+    values = [
+        array_ops.ones([1], dtype=dtypes.float32),
+        array_ops.ones([2], dtype=dtypes.float32),
+    ]
+    per_replica_values = [value_lib.PerReplica([v, v]) for v in values]
+    packs = cross_device_utils.pack_by_size(
+        per_replica_values, bytes_per_pack=0)
+    self.assertLen(packs, 1)
+    self.assertLen(packs[0], 2)
+    self.assertShape(packs[0][0], [1])
+    self.assertShape(packs[0][1], [2])
+
+  def testUnknownShape(self):
+    per_replica_values = [
+        value_lib.PerReplica([
+            array_ops.ones([10, 10], dtype=dtypes.float32),
+            array_ops.ones([10, 10], dtype=dtypes.float32),
+        ]),
+        value_lib.PerReplica([
+            array_ops.ones([10, 10], dtype=dtypes.float32),
+            input_layer.Input(
+                shape=(10), batch_size=None, dtype=dtypes.float32),
+        ]),
+    ]
+    packs = cross_device_utils.pack_by_size(
+        per_replica_values, bytes_per_pack=1)
+    self.assertLen(packs, 1)
+    self.assertEqual(packs[0], per_replica_values)
+
+  def testInconsistentShape(self):
+    per_replica_values = [
+        value_lib.PerReplica([
+            array_ops.ones([10, 10], dtype=dtypes.float32),
+            array_ops.ones([10, 10], dtype=dtypes.float32),
+        ]),
+        value_lib.PerReplica([
+            array_ops.ones([10, 10], dtype=dtypes.float32),
+            # This shape differs from the first value in the same PerReplica.
+            array_ops.ones([10, 5], dtype=dtypes.float32),
+        ]),
+    ]
+    packs = cross_device_utils.pack_by_size(
+        per_replica_values, bytes_per_pack=1)
+    self.assertLen(packs, 1)
+    self.assertEqual(packs[0], per_replica_values)
 
 
 if __name__ == "__main__":
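As a companion to `pack_by_size` and `testPreferLargerPack` above, here is a standalone sketch (plain Python, not TensorFlow code) of the greedy packing rule on raw byte sizes: a new pack is opened only once the current pack already exceeds `bytes_per_pack`, so every pack except possibly the last is at least that large.

```python
def pack_sizes(sizes, bytes_per_pack):
  """Greedy packing over plain sizes, mirroring pack_by_size."""
  if bytes_per_pack == 0:
    return [sizes]  # no packing requested
  packs, last_pack_size = [], 0
  for size in sizes:
    # Open a new pack only after the current one already exceeds the budget.
    if not packs or last_pack_size > bytes_per_pack:
      packs.append([])
      last_pack_size = 0
    packs[-1].append(size)
    last_pack_size += size
  return packs

print(pack_sizes([128, 32, 800, 4], bytes_per_pack=200))
# [[128, 32, 800], [4]] -- the same grouping testPreferLargerPack expects.
```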
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index 7fac3fa3d9e..1d1e44f97c9 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -108,6 +108,7 @@ import six
 from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
 from tensorflow.python.autograph.impl import api as autograph
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import collective_util
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import numpy_dataset
@@ -1719,10 +1720,10 @@ class StrategyExtendedV2(object):
   def _reduce(self, reduce_op, value):
     # Default implementation until we have an implementation for each strategy.
     return self._local_results(
-        self._reduce_to(reduce_op, value,
-                        device_util.current() or "/device:CPU:0"))[0]
+        self.reduce_to(reduce_op, value,
+                       device_util.current() or "/device:CPU:0"))[0]
 
-  def reduce_to(self, reduce_op, value, destinations):
+  def reduce_to(self, reduce_op, value, destinations, experimental_hints=None):
     """Combine (via e.g. sum or mean) values across replicas.
 
     Args:
@@ -1732,6 +1733,8 @@ class StrategyExtendedV2(object):
         string. The return value will be copied to all destination devices (or
         all the devices where the `destinations` value resides). To perform an
         all-reduction, pass `value` to `destinations`.
+      experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
+        to perform collective operations.
 
     Returns:
       A tensor or value mirrored to `destinations`.
@@ -1744,18 +1747,25 @@ class StrategyExtendedV2(object):
       reduce_op = reduce_util.ReduceOp(reduce_op.upper())
     assert (reduce_op == reduce_util.ReduceOp.SUM or
             reduce_op == reduce_util.ReduceOp.MEAN)
-    return self._reduce_to(reduce_op, value, destinations)
+    if experimental_hints is None:
+      experimental_hints = collective_util.Hints()
+    return self._reduce_to(reduce_op, value, destinations, experimental_hints)
 
-  def _reduce_to(self, reduce_op, value, destinations):
+  def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
     raise NotImplementedError("must be implemented in descendants")
 
-  def batch_reduce_to(self, reduce_op, value_destination_pairs):
+  def batch_reduce_to(self,
+                      reduce_op,
+                      value_destination_pairs,
+                      experimental_hints=None):
     """Combine multiple `reduce_to` calls into one for faster execution.
 
     Args:
       reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
-      value_destination_pairs: A sequence of (value, destinations)
-        pairs. See `reduce_to()` for a description.
+      value_destination_pairs: A sequence of (value, destinations) pairs. See
+        `reduce_to()` for a description.
+      experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
+        to perform collective operations.
 
     Returns:
       A list of mirrored values, one per pair in `value_destination_pairs`.
@@ -1765,11 +1775,16 @@ class StrategyExtendedV2(object):
     assert not isinstance(reduce_op, variable_scope.VariableAggregation)
     if isinstance(reduce_op, six.string_types):
       reduce_op = reduce_util.ReduceOp(reduce_op.upper())
-    return self._batch_reduce_to(reduce_op, value_destination_pairs)
+    if experimental_hints is None:
+      experimental_hints = collective_util.Hints()
+    return self._batch_reduce_to(reduce_op, value_destination_pairs,
+                                 experimental_hints)
 
-  def _batch_reduce_to(self, reduce_op, value_destination_pairs):
+  def _batch_reduce_to(self, reduce_op, value_destination_pairs,
+                       experimental_hints):
     return [
-        self.reduce_to(reduce_op, t, destinations=v)
+        self.reduce_to(
+            reduce_op, t, destinations=v, experimental_hints=experimental_hints)
         for t, v in value_destination_pairs
     ]
 
@@ -2267,7 +2282,7 @@ class ReplicaContext(object):
     require_replica_context(self)
     return (device_util.current(),)
 
-  def all_reduce(self, reduce_op, value):
+  def all_reduce(self, reduce_op, value, experimental_hints=None):
     """All-reduces the given `value Tensor` nest across replicas.
 
     If `all_reduce` is called in any replica, it must be called in all replicas.
@@ -2289,16 +2304,21 @@ class ReplicaContext(object):
       reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
       value: The nested structure of `Tensor`s to all-reduce. The structure must
         be compatible with `tf.nest`.
+      experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
+        to perform collective operations.
 
     Returns:
        A `Tensor` nest with the reduced `value`s from each replica.
     """
     if isinstance(reduce_op, six.string_types):
       reduce_op = reduce_util.ReduceOp(reduce_op.upper())
+    if experimental_hints is None:
+      experimental_hints = collective_util.Hints()
 
     def batch_all_reduce(strategy, *value_flat):
       return strategy.extended.batch_reduce_to(
-          reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat])
+          reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat],
+          experimental_hints)
 
     if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]:
       # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad.
@@ -2449,9 +2469,9 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
         replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
       return fn(*args, **kwargs)
 
-  def _reduce_to(self, reduce_op, value, destinations):
+  def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
     # TODO(josh11b): Use destinations?
-    del reduce_op, destinations
+    del reduce_op, destinations, experimental_hints
     return value
 
   def _update(self, var, fn, args, kwargs, group):
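Strategies that subclass `StrategyExtendedV1`/`V2` outside this patch also need to accept the new argument in their private reduce hooks, as the updated `_TestExtended` in the next hunk shows. A hypothetical minimal subclass (names are illustrative, not part of TensorFlow) might look like:

```python
from tensorflow.python.distribute import distribute_lib

class _PassThroughExtended(distribute_lib.StrategyExtendedV1):
  """Hypothetical subclass showing the updated private reduce signatures."""

  def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
    del reduce_op, destinations, experimental_hints  # trivial pass-through
    return value

  def _batch_reduce_to(self, reduce_op, value_destination_pairs,
                       experimental_hints):
    return [
        self._reduce_to(reduce_op, v, d, experimental_hints)
        for v, d in value_destination_pairs
    ]
```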
diff --git a/tensorflow/python/distribute/distribute_lib_test.py b/tensorflow/python/distribute/distribute_lib_test.py
index 8c7ad0ae40d..bac623ada52 100644
--- a/tensorflow/python/distribute/distribute_lib_test.py
+++ b/tensorflow/python/distribute/distribute_lib_test.py
@@ -95,8 +95,8 @@ class _TestExtended(distribute_lib.StrategyExtendedV1):
   def _local_results(self, value):
     return (value,)
 
-  def _reduce_to(self, reduce_op, value, destinations):
-    del reduce_op, destinations
+  def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
+    del reduce_op, destinations, experimental_hints
     return value
 
   def _experimental_make_numpy_dataset(self, numpy_input, session):
diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
index e57c656139a..baa6e1ac76e 100644
--- a/tensorflow/python/distribute/mirrored_strategy.py
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -788,7 +788,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
   def _get_cross_device_ops(self):
     return self._cross_device_ops or self._inferred_cross_device_ops
 
-  def _reduce_to(self, reduce_op, value, destinations):
+  def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
     if (isinstance(value, values.Mirrored) and
         reduce_op == reduce_util.ReduceOp.MEAN):
       return value
@@ -801,11 +801,16 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
       return cross_device_ops_lib.reduce_non_distributed_value(
           reduce_op, value, destinations, self._num_replicas_in_sync)
     return self._get_cross_device_ops().reduce(
-        reduce_op, value, destinations=destinations)
+        reduce_op,
+        value,
+        destinations=destinations,
+        experimental_hints=experimental_hints)
 
-  def _batch_reduce_to(self, reduce_op, value_destination_pairs):
+  def _batch_reduce_to(self, reduce_op, value_destination_pairs,
+                       experimental_hints):
     return self._get_cross_device_ops().batch_reduce(reduce_op,
-                                                     value_destination_pairs)
+                                                     value_destination_pairs,
+                                                     experimental_hints)
 
   def _update(self, var, fn, args, kwargs, group):
     # TODO(josh11b): In eager mode, use one thread per device.
diff --git a/tensorflow/python/distribute/one_device_strategy.py b/tensorflow/python/distribute/one_device_strategy.py
index 1fb3538b517..5a6973a699b 100644
--- a/tensorflow/python/distribute/one_device_strategy.py
+++ b/tensorflow/python/distribute/one_device_strategy.py
@@ -356,8 +356,8 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
     with ops.device(self._device), _OneDeviceReplicaContext(strategy):
       return fn(*args, **kwargs)
 
-  def _reduce_to(self, reduce_op, value, destinations):
-    del reduce_op, destinations
+  def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
+    del reduce_op, destinations, experimental_hints
     return value
 
   def _update(self, var, fn, args, kwargs, group):
diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py
index d27bacf6be7..f785a0c6266 100644
--- a/tensorflow/python/distribute/parameter_server_strategy.py
+++ b/tensorflow/python/distribute/parameter_server_strategy.py
@@ -466,20 +466,25 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
             "Cannot reduce to another worker: %r, current worker is %r" %
             (d, self._input_workers.worker_devices[0]))
 
-  def _reduce_to(self, reduce_op, value, destinations):
+  def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
     self._verify_destinations_not_different_worker(destinations)
     if not isinstance(value, values.DistributedValues):
       # pylint: disable=protected-access
       return cross_device_ops_lib.reduce_non_distributed_value(
           reduce_op, value, destinations, self._num_replicas_in_sync)
     return self._cross_device_ops.reduce(
-        reduce_op, value, destinations=destinations)
+        reduce_op,
+        value,
+        destinations=destinations,
+        experimental_hints=experimental_hints)
 
-  def _batch_reduce_to(self, reduce_op, value_destination_pairs):
+  def _batch_reduce_to(self, reduce_op, value_destination_pairs,
+                       experimental_hints):
     for _, destinations in value_destination_pairs:
       self._verify_destinations_not_different_worker(destinations)
     return self._cross_device_ops.batch_reduce(reduce_op,
-                                               value_destination_pairs)
+                                               value_destination_pairs,
+                                               experimental_hints)
 
   def _select_single_value(self, structured):
     """Select any single value in `structured`."""
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 0ebb2918bb1..7176a2e2dc9 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -659,7 +659,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
                                            tpu_values.TPUSyncOnReadVariable,
                                            **kwargs)
 
-  def _reduce_to(self, reduce_op, value, destinations):
+  def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
     if (isinstance(value, values.DistributedValues) or
         tensor_util.is_tensor(value)
        ) and tpu_values.enclosing_tpu_context() is not None:
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-cross-device-ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-cross-device-ops.pbtxt
index a2ea2343241..0a8e0b4421a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-cross-device-ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-cross-device-ops.pbtxt
@@ -8,11 +8,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -24,10 +24,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt
index a38c4b21d56..b5ccb39d075 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt
@@ -10,11 +10,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -26,10 +26,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-nccl-all-reduce.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-nccl-all-reduce.pbtxt
index bdc09bcd84b..1a039b10501 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-nccl-all-reduce.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-nccl-all-reduce.pbtxt
@@ -10,11 +10,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -26,10 +26,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduction-to-one-device.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduction-to-one-device.pbtxt
index f5ade9f86ba..7876166dc40 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduction-to-one-device.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduction-to-one-device.pbtxt
@@ -9,11 +9,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -25,10 +25,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt
index c3b79911757..0101212e4cc 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt
@@ -24,7 +24,7 @@ tf_class {
   }
   member_method {
     name: "all_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "merge_call"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt
index 8d5f6daff47..b2d9d4ee2cb 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt
@@ -37,7 +37,7 @@ tf_class {
   }
   member_method {
     name: "batch_reduce_to"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "broadcast_to"
@@ -69,7 +69,7 @@ tf_class {
   }
   member_method {
     name: "reduce_to"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "update"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-collective-hints.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-collective-hints.pbtxt
new file mode 100644
index 00000000000..c010134466c
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-collective-hints.pbtxt
@@ -0,0 +1,9 @@
+path: "tensorflow.distribute.experimental.CollectiveHints"
+tf_class {
+  is_instance: "<class \'tensorflow.python.distribute.collective_util.Hints\'>"
+  is_instance: "<type \'object\'>"
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'bytes_per_pack\'], varargs=None, keywords=None, defaults=[\'0\'], "
+  }
+}
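The new golden file above corresponds to `tf.distribute.experimental.CollectiveHints`, whose only constructor argument is `bytes_per_pack` (default 0). A small, hypothetical sketch of constructing it and forwarding it through `ReplicaContext.all_reduce`, whose argspec also gains `experimental_hints` in this change; the pack size and `replica_fn` are illustrative only:

```python
import tensorflow as tf

# bytes_per_pack defaults to 0, per the argspec in the golden file above; the
# 32 MB value here is purely illustrative.
hints = tf.distribute.experimental.CollectiveHints(bytes_per_pack=32 * 1024 * 1024)


def replica_fn(value):
  # Intended to run inside a tf.distribute strategy's replica context.
  ctx = tf.distribute.get_replica_context()
  # experimental_hints defaults to None, so omitting it keeps the old behavior.
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value,
                        experimental_hints=hints)
```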
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.pbtxt
index 879cceec3ac..9247db37925 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.pbtxt
@@ -8,6 +8,10 @@ tf_module {
     name: "CollectiveCommunication"
     mtype: "<class \'enum.EnumMeta\'>"
   }
+  member {
+    name: "CollectiveHints"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "MultiWorkerMirroredStrategy"
     mtype: "<type \'type\'>"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-cross-device-ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-cross-device-ops.pbtxt
index a2ea2343241..0a8e0b4421a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-cross-device-ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-cross-device-ops.pbtxt
@@ -8,11 +8,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -24,10 +24,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt
index a38c4b21d56..b5ccb39d075 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-hierarchical-copy-all-reduce.pbtxt
@@ -10,11 +10,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -26,10 +26,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-nccl-all-reduce.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-nccl-all-reduce.pbtxt
index bdc09bcd84b..1a039b10501 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-nccl-all-reduce.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-nccl-all-reduce.pbtxt
@@ -10,11 +10,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -26,10 +26,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduction-to-one-device.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduction-to-one-device.pbtxt
index f5ade9f86ba..7876166dc40 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduction-to-one-device.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduction-to-one-device.pbtxt
@@ -9,11 +9,11 @@ tf_class {
   }
   member_method {
     name: "batch_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "batch_reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "broadcast"
@@ -25,10 +25,10 @@ tf_class {
   }
   member_method {
     name: "reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "reduce_implementation"
-    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'per_replica_value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=None"
   }
 }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt
index c3b79911757..0101212e4cc 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt
@@ -24,7 +24,7 @@ tf_class {
   }
   member_method {
     name: "all_reduce"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "merge_call"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt
index c5e87b33070..f3fa80427a4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt
@@ -20,7 +20,7 @@ tf_class {
   }
   member_method {
     name: "batch_reduce_to"
-    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "colocate_vars_with"
@@ -32,7 +32,7 @@ tf_class {
   }
   member_method {
     name: "reduce_to"
-    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\', \'experimental_hints\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "update"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-collective-hints.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-collective-hints.pbtxt
new file mode 100644
index 00000000000..c010134466c
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-collective-hints.pbtxt
@@ -0,0 +1,9 @@
+path: "tensorflow.distribute.experimental.CollectiveHints"
+tf_class {
+  is_instance: "<class \'tensorflow.python.distribute.collective_util.Hints\'>"
+  is_instance: "<type \'object\'>"
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'bytes_per_pack\'], varargs=None, keywords=None, defaults=[\'0\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.pbtxt
index 879cceec3ac..9247db37925 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.pbtxt
@@ -8,6 +8,10 @@ tf_module {
     name: "CollectiveCommunication"
     mtype: "<class \'enum.EnumMeta\'>"
   }
+  member {
+    name: "CollectiveHints"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "MultiWorkerMirroredStrategy"
     mtype: "<type \'type\'>"