Fix the bug in collective_all_reduce_strategy where the wrong cross device op was used.
PiperOrigin-RevId: 225096446
parent e8c65fa77f
commit 45a6696c0a
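The attribute mix-up is easy to see in a stripped-down sketch. The names _ExtendedSketch, _initialize_before_fix, and _initialize_after_fix below are illustrative only, not TensorFlow APIs; the sketch assumes, as the new asserts in the diff imply, that _get_cross_device_ops() reads self._cross_device_ops and falls back to a default reduction op when it is unset.

# Illustrative sketch only -- not TensorFlow source. It mimics the bug: the
# collective all-reduce op was stored under the stale `_cross_tower_ops`
# name, so the getter never returned it and a default reduction op was used.

class _ExtendedSketch(object):

  def __init__(self):
    self._cross_device_ops = None   # attribute the getter actually reads
    self._cross_tower_ops = None    # legacy name, ignored by the getter

  def _get_cross_device_ops(self):
    # Fall back to a default op when no collective op was assigned.
    return self._cross_device_ops or "default_reduction_op"

  def _initialize_before_fix(self, collective_op):
    self._cross_tower_ops = collective_op    # bug: wrong attribute

  def _initialize_after_fix(self, collective_op):
    self._cross_device_ops = collective_op   # fix: getter now sees it


sketch = _ExtendedSketch()
sketch._initialize_before_fix("CollectiveAllReduce")
print(sketch._get_cross_device_ops())   # "default_reduction_op" -- wrong op

sketch._initialize_after_fix("CollectiveAllReduce")
print(sketch._get_cross_device_ops())   # "CollectiveAllReduce" -- fixed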
@@ -70,6 +70,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     self._cross_device_ops = None
     self._num_gpus_per_worker = num_gpus_per_worker
     self._initialize_local_worker(num_gpus_per_worker)
+    assert isinstance(self._get_cross_device_ops(),
+                      cross_device_ops_lib.CollectiveAllReduce)
 
   def _initialize_local_worker(self, num_gpus_per_worker):
     """Initializes the object for local training."""
@@ -86,7 +88,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
 
     self._collective_keys = cross_device_utils.CollectiveKeys()
     self._initialize_local(local_devices)
-    self._cross_tower_ops = cross_device_ops_lib.CollectiveAllReduce(
+    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
         num_workers=self._num_workers,
         num_gpus_per_worker=num_gpus_per_worker,
         collective_keys=self._collective_keys)
@@ -128,7 +130,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
 
     self._collective_keys = cross_device_utils.CollectiveKeys()
     self._initialize_local(local_devices)
-    self._cross_tower_ops = cross_device_ops_lib.CollectiveAllReduce(
+    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
         num_workers=self._num_workers,
         num_gpus_per_worker=num_gpus_per_worker,
         collective_keys=self._collective_keys)
@@ -267,6 +269,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
       # already been initialized with a `cluster_spec`.
       self._initialize_multi_worker(self._num_gpus_per_worker, cluster_spec,
                                     task_type, task_id)
+      assert isinstance(self._get_cross_device_ops(),
+                        cross_device_ops_lib.CollectiveAllReduce)
 
     if session_config:
       session_config.CopyFrom(self._update_config_proto(session_config))
@@ -82,7 +82,7 @@ class CollectiveAllReduceStrategyTestBase(
         instance_key_with_id_start=num_gpus * 10000 +
         CollectiveAllReduceStrategyTestBase.collective_key_base)
     distribution.extended._collective_keys = collective_keys
-    distribution.extended._inferred_cross_device_ops._collective_keys = (
+    distribution.extended._cross_device_ops._collective_keys = (
         collective_keys)
     if task_type and task_id is not None:
       return distribution, 'grpc://' + self._cluster_spec[task_type][
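The test hunk above mirrors the rename: the deterministic collective keys have to be attached through _cross_device_ops, since that is now where the strategy stores its CollectiveAllReduce op. A hypothetical, self-contained sketch of that override pattern (the Fake* classes below are stand-ins, not TensorFlow classes):

# Hypothetical stand-ins, not TensorFlow classes: shows the override pattern
# used by the test, where one keys object is installed both on the extended
# strategy and on the op it now stores under `_cross_device_ops`.

class FakeCollectiveAllReduce(object):

  def __init__(self):
    self._collective_keys = None


class FakeExtended(object):

  def __init__(self):
    self._collective_keys = None
    self._cross_device_ops = FakeCollectiveAllReduce()


extended = FakeExtended()
deterministic_keys = object()   # stand-in for a CollectiveKeys(...) instance
extended._collective_keys = deterministic_keys
extended._cross_device_ops._collective_keys = deterministic_keys
assert extended._cross_device_ops._collective_keys is deterministic_keys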