diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
index 5c50a204904..346513dc586 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -70,6 +70,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     self._cross_device_ops = None
     self._num_gpus_per_worker = num_gpus_per_worker
     self._initialize_local_worker(num_gpus_per_worker)
+    assert isinstance(self._get_cross_device_ops(),
+                      cross_device_ops_lib.CollectiveAllReduce)
 
   def _initialize_local_worker(self, num_gpus_per_worker):
     """Initializes the object for local training."""
@@ -86,7 +88,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     self._collective_keys = cross_device_utils.CollectiveKeys()
     self._initialize_local(local_devices)
 
-    self._cross_tower_ops = cross_device_ops_lib.CollectiveAllReduce(
+    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
         num_workers=self._num_workers,
         num_gpus_per_worker=num_gpus_per_worker,
         collective_keys=self._collective_keys)
@@ -128,7 +130,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     self._collective_keys = cross_device_utils.CollectiveKeys()
     self._initialize_local(local_devices)
 
-    self._cross_tower_ops = cross_device_ops_lib.CollectiveAllReduce(
+    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
         num_workers=self._num_workers,
         num_gpus_per_worker=num_gpus_per_worker,
         collective_keys=self._collective_keys)
@@ -267,6 +269,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     # already been initialized with a `cluster_spec`.
     self._initialize_multi_worker(self._num_gpus_per_worker, cluster_spec,
                                   task_type, task_id)
+    assert isinstance(self._get_cross_device_ops(),
+                      cross_device_ops_lib.CollectiveAllReduce)
 
     if session_config:
       session_config.CopyFrom(self._update_config_proto(session_config))
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
index 8a9e583f0af..6d7cd14ed5a 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
@@ -82,7 +82,7 @@ class CollectiveAllReduceStrategyTestBase(
         instance_key_with_id_start=num_gpus * 10000 +
         CollectiveAllReduceStrategyTestBase.collective_key_base)
     distribution.extended._collective_keys = collective_keys
-    distribution.extended._inferred_cross_device_ops._collective_keys = (
+    distribution.extended._cross_device_ops._collective_keys = (
        collective_keys)
    if task_type and task_id is not None:
      return distribution, 'grpc://' + self._cluster_spec[task_type][
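
Note: the rename from `_cross_tower_ops` to `_cross_device_ops` matters because the latter is the attribute the base class's `_get_cross_device_ops()` accessor reads; assigning to the stale name would leave the strategy silently falling back to inferred cross-device ops. Below is a minimal, self-contained sketch of the failure mode the new assertions guard against. The class names mirror the diff, but the simplified bodies and the `InferredOps` fallback are illustrative assumptions, not the real TensorFlow implementation.

class CollectiveAllReduce(object):
  """Stand-in for cross_device_ops_lib.CollectiveAllReduce."""


class InferredOps(object):
  """Stand-in for the fallback ops used when _cross_device_ops is unset."""


class StrategyExtended(object):

  def __init__(self):
    self._cross_device_ops = None
    self._initialize_local_worker()
    # The assertion added in the diff: fail fast if initialization wrote
    # to a stale attribute name and the accessor fell back to inferred ops.
    assert isinstance(self._get_cross_device_ops(), CollectiveAllReduce)

  def _initialize_local_worker(self):
    # Before the fix this assignment targeted `self._cross_tower_ops`,
    # leaving `_cross_device_ops` as None.
    self._cross_device_ops = CollectiveAllReduce()

  def _get_cross_device_ops(self):
    # Mirrors the accessor's behavior: use the configured ops if present,
    # otherwise fall back to inferred ops.
    return self._cross_device_ops or InferredOps()


StrategyExtended()  # Passes; restoring the old attribute name would raise.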