Fix the bug in collective_all_reduce_strategy where the wrong cross-device op was used.

PiperOrigin-RevId: 225096446
Authored by Yuefeng Zhou on 2018-12-11 16:55:34 -08:00; committed by TensorFlower Gardener
parent e8c65fa77f
commit 45a6696c0a
2 changed files with 7 additions and 3 deletions
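
For context, below is a minimal standalone sketch of the failure mode this change guards against. The classes are hypothetical stand-ins, not TensorFlow's actual API: the buggy subclass stores its collective op under the stale _cross_tower_ops name, while the base class only ever reads _cross_device_ops, so the collective op is silently ignored and a default op is used instead. The asserts added in this commit catch exactly that.

class DefaultCrossDeviceOp:
  """Stand-in for whatever default op the base class falls back to."""


class CollectiveAllReduce:
  """Stand-in for the collective all-reduce cross-device op."""


class MirroredExtendedSketch:
  """Stand-in for the base strategy class that owns the cross-device op."""

  def __init__(self):
    self._cross_device_ops = None

  def _get_cross_device_ops(self):
    # The base class only ever reads `_cross_device_ops`; anything stored
    # under a different attribute name is invisible here.
    return self._cross_device_ops or DefaultCrossDeviceOp()


class BuggyCollectiveExtendedSketch(MirroredExtendedSketch):
  def __init__(self):
    super().__init__()
    # Bug: the op is stored under the stale `_cross_tower_ops` name, so the
    # base class silently falls back to the default op.
    self._cross_tower_ops = CollectiveAllReduce()


class FixedCollectiveExtendedSketch(MirroredExtendedSketch):
  def __init__(self):
    super().__init__()
    # Fix: store the op under the attribute the base class actually reads,
    # and assert on the resolved type, mirroring the asserts added below.
    self._cross_device_ops = CollectiveAllReduce()
    assert isinstance(self._get_cross_device_ops(), CollectiveAllReduce)


print(type(BuggyCollectiveExtendedSketch()._get_cross_device_ops()).__name__)
# -> DefaultCrossDeviceOp (the collective op was silently ignored)
print(type(FixedCollectiveExtendedSketch()._get_cross_device_ops()).__name__)
# -> CollectiveAllReduce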

@@ -70,6 +70,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     self._cross_device_ops = None
     self._num_gpus_per_worker = num_gpus_per_worker
     self._initialize_local_worker(num_gpus_per_worker)
+    assert isinstance(self._get_cross_device_ops(),
+                      cross_device_ops_lib.CollectiveAllReduce)
 
   def _initialize_local_worker(self, num_gpus_per_worker):
     """Initializes the object for local training."""
@@ -86,7 +88,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     self._collective_keys = cross_device_utils.CollectiveKeys()
     self._initialize_local(local_devices)
-    self._cross_tower_ops = cross_device_ops_lib.CollectiveAllReduce(
+    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
         num_workers=self._num_workers,
         num_gpus_per_worker=num_gpus_per_worker,
         collective_keys=self._collective_keys)
@@ -128,7 +130,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     self._collective_keys = cross_device_utils.CollectiveKeys()
     self._initialize_local(local_devices)
-    self._cross_tower_ops = cross_device_ops_lib.CollectiveAllReduce(
+    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        num_workers=self._num_workers,
        num_gpus_per_worker=num_gpus_per_worker,
        collective_keys=self._collective_keys)
@@ -267,6 +269,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
       # already been initialized with a `cluster_spec`.
       self._initialize_multi_worker(self._num_gpus_per_worker, cluster_spec,
                                     task_type, task_id)
+      assert isinstance(self._get_cross_device_ops(),
+                        cross_device_ops_lib.CollectiveAllReduce)
 
     if session_config:
       session_config.CopyFrom(self._update_config_proto(session_config))

@@ -82,7 +82,7 @@ class CollectiveAllReduceStrategyTestBase(
         instance_key_with_id_start=num_gpus * 10000 +
         CollectiveAllReduceStrategyTestBase.collective_key_base)
     distribution.extended._collective_keys = collective_keys
-    distribution.extended._inferred_cross_device_ops._collective_keys = (
+    distribution.extended._cross_device_ops._collective_keys = (
        collective_keys)
     if task_type and task_id is not None:
       return distribution, 'grpc://' + self._cluster_spec[task_type][