Fix the bug in collective_all_reduce_strategy where the wrong cross device op is used.
PiperOrigin-RevId: 225096446
commit 45a6696c0a
parent e8c65fa77f
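The root cause, illustrated below as a standalone sketch (placeholder class and string values, not TensorFlow source): the extended strategy reads self._cross_device_ops through _get_cross_device_ops() and falls back to an inferred default op when it is unset, but the collective strategy was still assigning to the old self._cross_tower_ops name, so the default op was used instead of CollectiveAllReduce.

# Standalone illustration of the bug pattern; names and string values are
# placeholders, not TensorFlow code.
class _ExtendedSketch(object):

  def __init__(self, fixed):
    self._cross_device_ops = None        # the attribute the getter reads
    self._inferred_cross_device_ops = "ReductionToOneDevice (default)"
    if fixed:
      self._cross_device_ops = "CollectiveAllReduce"   # after the fix
    else:
      self._cross_tower_ops = "CollectiveAllReduce"    # before: stale name

  def _get_cross_device_ops(self):
    # Falls back to the inferred default whenever _cross_device_ops is unset.
    return self._cross_device_ops or self._inferred_cross_device_ops


print(_ExtendedSketch(fixed=False)._get_cross_device_ops())  # default op (bug)
print(_ExtendedSketch(fixed=True)._get_cross_device_ops())   # CollectiveAllReduce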
@@ -70,6 +70,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     self._cross_device_ops = None
     self._num_gpus_per_worker = num_gpus_per_worker
     self._initialize_local_worker(num_gpus_per_worker)
+    assert isinstance(self._get_cross_device_ops(),
+                      cross_device_ops_lib.CollectiveAllReduce)
 
   def _initialize_local_worker(self, num_gpus_per_worker):
     """Initializes the object for local training."""
@@ -86,7 +88,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
 
     self._collective_keys = cross_device_utils.CollectiveKeys()
     self._initialize_local(local_devices)
-    self._cross_tower_ops = cross_device_ops_lib.CollectiveAllReduce(
+    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
         num_workers=self._num_workers,
         num_gpus_per_worker=num_gpus_per_worker,
         collective_keys=self._collective_keys)
@@ -128,7 +130,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
 
     self._collective_keys = cross_device_utils.CollectiveKeys()
     self._initialize_local(local_devices)
-    self._cross_tower_ops = cross_device_ops_lib.CollectiveAllReduce(
+    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
         num_workers=self._num_workers,
         num_gpus_per_worker=num_gpus_per_worker,
         collective_keys=self._collective_keys)
@@ -267,6 +269,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
       # already been initialized with a `cluster_spec`.
       self._initialize_multi_worker(self._num_gpus_per_worker, cluster_spec,
                                     task_type, task_id)
+      assert isinstance(self._get_cross_device_ops(),
+                        cross_device_ops_lib.CollectiveAllReduce)
 
     if session_config:
       session_config.CopyFrom(self._update_config_proto(session_config))
@@ -82,7 +82,7 @@ class CollectiveAllReduceStrategyTestBase(
         instance_key_with_id_start=num_gpus * 10000 +
         CollectiveAllReduceStrategyTestBase.collective_key_base)
     distribution.extended._collective_keys = collective_keys
-    distribution.extended._inferred_cross_device_ops._collective_keys = (
+    distribution.extended._cross_device_ops._collective_keys = (
         collective_keys)
     if task_type and task_id is not None:
       return distribution, 'grpc://' + self._cluster_spec[task_type][
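A quick check mirroring the assertions added above (a hedged sketch: the module paths, the num_gpus_per_worker constructor argument, and the private _get_cross_device_ops() accessor are taken from the TF 1.x-era code in this diff and may differ in other versions):

from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib

# Local, single-worker strategy; num_gpus_per_worker=0 keeps it on CPU.
strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
    num_gpus_per_worker=0)
# With this change the configured op is CollectiveAllReduce rather than the
# inferred/default cross device op.
assert isinstance(strategy.extended._get_cross_device_ops(),
                  cross_device_ops_lib.CollectiveAllReduce)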