Do not use NCCL when reducing tensors on CPUs.
PiperOrigin-RevId: 338387045 Change-Id: I9c2f4d8b9831d7102bb6d0df3d3c9ba1be3720d1
This commit is contained in:
parent
60cf96348b
commit
95a74d1a98
@ -803,7 +803,10 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
|
|||||||
def reduce_implementation(self, reduce_op, per_replica_value, destinations,
|
def reduce_implementation(self, reduce_op, per_replica_value, destinations,
|
||||||
options):
|
options):
|
||||||
del options # Unused.
|
del options # Unused.
|
||||||
if _devices_match(per_replica_value, destinations):
|
# To use NCCL or all-reduce, source and destination devices should match,
|
||||||
|
# and none of the devices should be CPU.
|
||||||
|
if (_devices_match(per_replica_value, destinations) and
|
||||||
|
not any("cpu" in d.lower() for d in get_devices_from(destinations))):
|
||||||
return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
|
return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
|
||||||
else:
|
else:
|
||||||
return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value,
|
return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value,
|
||||||
|
@ -1456,7 +1456,7 @@ class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase,
|
|||||||
input_options):
|
input_options):
|
||||||
|
|
||||||
def dataset_fn(input_context): # pylint: disable=[unused-argument]
|
def dataset_fn(input_context): # pylint: disable=[unused-argument]
|
||||||
return dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
|
return dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4])
|
||||||
|
|
||||||
ds = distribution.experimental_distribute_datasets_from_function(
|
ds = distribution.experimental_distribute_datasets_from_function(
|
||||||
dataset_fn, input_options)
|
dataset_fn, input_options)
|
||||||
|
Loading…
Reference in New Issue
Block a user