diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py index 1b82261462e..c5aca728827 100644 --- a/tensorflow/python/distribute/cross_device_ops.py +++ b/tensorflow/python/distribute/cross_device_ops.py @@ -803,7 +803,10 @@ class AllReduceCrossDeviceOps(CrossDeviceOps): def reduce_implementation(self, reduce_op, per_replica_value, destinations, options): del options # Unused. - if _devices_match(per_replica_value, destinations): + # To use NCCL or all-reduce, source and destination devices should match, + # and none of the devices should be CPU. + if (_devices_match(per_replica_value, destinations) and + not any("cpu" in d.lower() for d in get_devices_from(destinations))): return self._batch_all_reduce(reduce_op, [per_replica_value])[0] else: return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value, diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py index 5d09096596f..2a2428994be 100644 --- a/tensorflow/python/distribute/input_lib_test.py +++ b/tensorflow/python/distribute/input_lib_test.py @@ -1456,7 +1456,7 @@ class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase, input_options): def dataset_fn(input_context): # pylint: disable=[unused-argument] - return dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) + return dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4]) ds = distribution.experimental_distribute_datasets_from_function( dataset_fn, input_options)