Fix cross_device_ops_test with multi GPU

The test doesn't limit the number of GPUs when creating
CollectiveAllReduceStrategy, which results in a mismatch in the number of devices.

PiperOrigin-RevId: 316769983
Change-Id: I50c11107a99348162b37615f209fd6fa6ee659d2
Author: Ran Chen (2020-06-16 15:27:59 -07:00), committed by TensorFlower Gardener
Parent: 0ea67f633b
Commit: 31e83f2f68
2 changed files with 10 additions and 4 deletions
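
For context, a short sketch (not part of the commit) of the mismatch the description refers to: constructing CollectiveAllReduceStrategy directly lets it enumerate every GPU on the host, while the test's own device list may contain only the CPU. The module paths match those used in the diffs below; the snippet is illustrative only.

# Illustrative only; assumes TensorFlow's private distribute modules as used in
# this commit. On a multi-GPU machine the strategy reports more devices than
# the test's explicit list, which is the mismatch the commit fixes.
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib

devices = ["/device:CPU:0"]  # what the test intends to run on
strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
    communication=cross_device_ops_lib.CollectiveCommunication.AUTO)
print(len(strategy.extended.worker_devices), len(devices))  # e.g. 2 vs 1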

tensorflow/python/distribute/collective_all_reduce_strategy.py

@@ -112,9 +112,12 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
         "num_replicas_per_worker").set(self.extended._num_gpus_per_worker)
 
   @classmethod
-  def _from_local_devices(cls, devices):
+  def _from_local_devices(
+      cls,
+      devices,
+      communication=cross_device_ops_lib.CollectiveCommunication.AUTO):
     """A convenience method to create an object with a list of devices."""
-    obj = cls()
+    obj = cls(communication)
     obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices)  # pylint: disable=protected-access
     return obj
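
With the signature above, callers can pin both the device list and the collective communication implementation when building a strategy for tests. A usage sketch (RING is only an illustrative choice; the method stays private, hence the pylint pragma):

# Example use of the extended _from_local_devices signature; RING is just an
# illustrative value, AUTO remains the default.
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib

strategy = (
    collective_all_reduce_strategy.CollectiveAllReduceStrategy
    ._from_local_devices(
        ["/device:CPU:0"],
        communication=cross_device_ops_lib.CollectiveCommunication.RING))  # pylint: disable=protected-access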

tensorflow/python/distribute/cross_device_ops_test.py

@@ -521,10 +521,13 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
         devices = ["/device:CPU:0"]
       if use_strategy_object:
-        strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
-            communication=communication)
+        strategy = (
+            collective_all_reduce_strategy.CollectiveAllReduceStrategy
+            ._from_local_devices(devices, communication=communication))  # pylint: disable=protected-access
         strategy.extended._collective_keys = collective_keys
         strategy.extended._cross_device_ops._collective_keys = collective_keys
+        strategy.extended._host_cross_device_ops._collective_keys = (
+            collective_keys)
         return strategy, devices, ""
       else:
         collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
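
As a sanity check (not part of the commit), a strategy built this way should report exactly as many compute devices as the list it was given, even on a multi-GPU host:

# Hypothetical check mirroring the fix: the device count now follows the
# explicit list rather than the number of GPUs on the machine.
from tensorflow.python.distribute import collective_all_reduce_strategy

devices = ["/device:CPU:0"]
strategy = (
    collective_all_reduce_strategy.CollectiveAllReduceStrategy
    ._from_local_devices(devices))  # pylint: disable=protected-access
assert len(strategy.extended.worker_devices) == len(devices)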