Fix cross_device_ops_test with multi GPU
The test doesn't limit the number of GPUs when creating
CollectiveAllReduceStrategy, which results in a mismatch in the number of
devices.

PiperOrigin-RevId: 316769983
Change-Id: I50c11107a99348162b37615f209fd6fa6ee659d2
commit 31e83f2f68
parent 0ea67f633b
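The mismatch the message describes: created without a device list,
CollectiveAllReduceStrategy enumerates every GPU visible to the process, so on
a host with more GPUs than a test case expects, the replica count no longer
matches the test's device list. A minimal sketch of the symptom, not taken
from the commit; the module paths match the diff below, but the surrounding
test scaffolding is assumed:

    # Hedged sketch of the pre-fix behavior; illustrative only.
    from tensorflow.python.distribute import collective_all_reduce_strategy
    from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib

    devices = ["/device:GPU:0", "/device:GPU:1"]  # what the test case expects

    # The bare constructor enumerates every visible GPU, so on a 4-GPU machine
    # the strategy ends up with four replicas while the test expects two.
    strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
        communication=cross_device_ops_lib.CollectiveCommunication.AUTO)
    print(strategy.num_replicas_in_sync)  # count of visible GPUs, not len(devices)
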
@@ -112,9 +112,12 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
         "num_replicas_per_worker").set(self.extended._num_gpus_per_worker)
 
   @classmethod
-  def _from_local_devices(cls, devices):
+  def _from_local_devices(
+      cls,
+      devices,
+      communication=cross_device_ops_lib.CollectiveCommunication.AUTO):
     """A convenience method to create an object with a list of devices."""
-    obj = cls()
+    obj = cls(communication)
     obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices)  # pylint: disable=protected-access
     return obj
 
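The hunk above extends the private _from_local_devices helper so callers can
pin the strategy to an explicit device list and still choose the collective
implementation. A hedged usage sketch, not from the commit (the helper is
underscore-private, hence the pylint disable carried over from the diff):

    # Hedged sketch: build a strategy over exactly two devices so the replica
    # count matches what the test expects.
    from tensorflow.python.distribute import collective_all_reduce_strategy
    from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib

    strategy = (
        collective_all_reduce_strategy.CollectiveAllReduceStrategy
        ._from_local_devices(  # pylint: disable=protected-access
            ["/device:GPU:0", "/device:GPU:1"],
            # NCCL assumes GPUs are present; RING and AUTO are the other options.
            communication=cross_device_ops_lib.CollectiveCommunication.NCCL))
    assert strategy.num_replicas_in_sync == 2
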
@@ -521,10 +521,13 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
       devices = ["/device:CPU:0"]
 
     if use_strategy_object:
-      strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
-          communication=communication)
+      strategy = (
+          collective_all_reduce_strategy.CollectiveAllReduceStrategy
+          ._from_local_devices(devices, communication=communication))  # pylint: disable=protected-access
       strategy.extended._collective_keys = collective_keys
       strategy.extended._cross_device_ops._collective_keys = collective_keys
       strategy.extended._host_cross_device_ops._collective_keys = (
           collective_keys)
       return strategy, devices, ""
     else:
       collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
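As the hunk shows, the test then swaps a single CollectiveKeys instance into
the strategy and both of its cross-device ops, keeping collective group and
instance keys deterministic so test cases don't collide. A sketch of
constructing such keys, assuming the cross_device_utils.CollectiveKeys
constructor of this era; the offsets are illustrative, not the test's actual
values:

    # Hedged sketch: fresh, deterministic collective keys for a test case.
    from tensorflow.python.distribute import cross_device_utils

    collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=10,
        op_instance_key_start=100,
        variable_instance_key_start=10000)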