From 95a74d1a988e63b6ebd9e5e5d15bcab0450b9456 Mon Sep 17 00:00:00 2001
From: Priya Gupta <priyag@google.com>
Date: Wed, 21 Oct 2020 19:19:34 -0700
Subject: [PATCH] Do not use NCCL when reducing tensors on CPUs.

PiperOrigin-RevId: 338387045
Change-Id: I9c2f4d8b9831d7102bb6d0df3d3c9ba1be3720d1
---
 tensorflow/python/distribute/cross_device_ops.py | 5 ++++-
 tensorflow/python/distribute/input_lib_test.py   | 2 +-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py
index 1b82261462e..c5aca728827 100644
--- a/tensorflow/python/distribute/cross_device_ops.py
+++ b/tensorflow/python/distribute/cross_device_ops.py
@@ -803,7 +803,10 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
   def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                             options):
     del options  # Unused.
-    if _devices_match(per_replica_value, destinations):
+    # To use NCCL or all-reduce, source and destination devices should match,
+    # and none of the devices should be CPU.
+    if (_devices_match(per_replica_value, destinations) and
+        not any("cpu" in d.lower() for d in get_devices_from(destinations))):
       return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
     else:
       return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value,
diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py
index 5d09096596f..2a2428994be 100644
--- a/tensorflow/python/distribute/input_lib_test.py
+++ b/tensorflow/python/distribute/input_lib_test.py
@@ -1456,7 +1456,7 @@ class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase,
                                                         input_options):
 
     def dataset_fn(input_context):  # pylint: disable=[unused-argument]
-      return dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
+      return dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4])
 
     ds = distribution.experimental_distribute_datasets_from_function(
         dataset_fn, input_options)