Remove the usage of _DefaultDistributionStrategy in keras.

PiperOrigin-RevId: 344934166
Change-Id: I7787e3ab2500796ec3bbe3e0f4a307f559be77b2
Scott Zhu 2020-11-30 19:18:05 -08:00 committed by TensorFlower Gardener
parent 4e17145925
commit 1feb592358
2 changed files with 6 additions and 20 deletions
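Background: `_DefaultDistributionStrategy` is the private object returned by `tf.distribute.get_strategy()` when no explicit strategy scope is active. A minimal sketch (using the public `tf.distribute` API, not code touched by this commit) of how the "no explicit strategy" case can be detected without naming that private class:

import tensorflow as tf

# With no strategy scope active, get_strategy() still returns an object (the
# single-replica default), so "no strategy" is detected via has_strategy()
# rather than by isinstance checks against the private default class.
print(tf.distribute.has_strategy())        # False: only the default strategy
print(type(tf.distribute.get_strategy()))  # the private default-strategy class

with tf.distribute.MirroredStrategy().scope():
  print(tf.distribute.has_strategy())        # True: an explicit strategy is active
  print(type(tf.distribute.get_strategy()))  # MirroredStrategy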

View File

@@ -35,7 +35,6 @@ import six
 from tensorflow.core.framework import summary_pb2
 from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.distribute import collective_all_reduce_strategy
-from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.distribute import tpu_strategy
 from tensorflow.python.eager import context
@@ -1603,7 +1602,6 @@ class BackupAndRestore(Callback):
     self.backup_dir = backup_dir
     self._supports_tf_logs = True
     self._supported_strategies = (
-        distribute_lib._DefaultDistributionStrategy,
         mirrored_strategy.MirroredStrategy,
         collective_all_reduce_strategy.CollectiveAllReduceStrategy,
         tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV2)
@@ -1631,8 +1629,8 @@ class BackupAndRestore(Callback):
     # failure-recovery of a worker in training.
     # pylint: disable=protected-access
-    if not isinstance(self.model.distribute_strategy,
-                      self._supported_strategies):
+    if self.model._distribution_strategy and not isinstance(
+        self.model.distribute_strategy, self._supported_strategies):
       raise NotImplementedError(
           '%s is not supported yet. '
           'Currently BackupAndRestore callback only supports empty strategy, '
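After this hunk, the callback validates the strategy type only when the model actually carries an explicit distribution strategy; with the default (empty) strategy, `model._distribution_strategy` is None and no error is raised. A rough sketch of that guard's behavior, where `FakeModel`, `UnsupportedStrategy`, and the empty `SUPPORTED_STRATEGIES` tuple are hypothetical stand-ins rather than the Keras implementation:

class UnsupportedStrategy(object):
  """Hypothetical strategy type that is not in the supported tuple."""


class FakeModel(object):
  """Stand-in for a Keras model.

  `_distribution_strategy` stays None unless the model was built under an
  explicit tf.distribute strategy scope.
  """

  def __init__(self, strategy=None):
    self._distribution_strategy = strategy

  @property
  def distribute_strategy(self):
    return self._distribution_strategy


SUPPORTED_STRATEGIES = ()  # stand-in for the callback's _supported_strategies


def check_strategy(model):
  # Mirrors the updated guard: skip validation entirely when no explicit
  # strategy is attached to the model.
  if model._distribution_strategy and not isinstance(
      model.distribute_strategy, SUPPORTED_STRATEGIES):
    raise NotImplementedError('%s is not supported yet.' %
                              type(model.distribute_strategy).__name__)


check_strategy(FakeModel())                         # empty strategy: no error
try:
  check_strategy(FakeModel(UnsupportedStrategy()))  # explicit but unsupported
except NotImplementedError as e:
  print(e)  # UnsupportedStrategy is not supported yet.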

View File

@@ -24,15 +24,12 @@ import numpy as np
 import six
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.distribute import central_storage_strategy
-from tensorflow.python.distribute import distribute_lib
-from tensorflow.python.distribute import mirrored_strategy
-from tensorflow.python.distribute import one_device_strategy
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import tpu_strategy
 from tensorflow.python.eager import context
 from tensorflow.python.framework import random_seed
 from tensorflow.python.framework import test_combinations as combinations
+from tensorflow.python.framework import test_util
 from tensorflow.python.keras.distribute import distributed_training_utils
 from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
 from tensorflow.python.keras.distribute.strategy_combinations import multi_worker_mirrored_strategies
@@ -324,18 +321,9 @@ def compare_results(results_with_ds,
 def _get_compare_result_tolerance(key):
   """Returns tolerance to compare results."""
-  # TODO(b/119257215): For MirroredStrategy, weights are not exactly the same,
-  # so use larger tolerance for now. Predict should be related to weights.
-  # Also for CentralStorageStrategy and OneDeviceStrategy which is observed in
-  # b/172956754.
-  if (isinstance(distribution,
-                 (mirrored_strategy.MirroredStrategy,
-                  mirrored_strategy.MirroredStrategyV1,
-                  central_storage_strategy.CentralStorageStrategy,
-                  central_storage_strategy.CentralStorageStrategyV1,
-                  one_device_strategy.OneDeviceStrategy,
-                  one_device_strategy.OneDeviceStrategyV1,
-                  distribute_lib._DefaultDistributionStrategy)) and  # pylint: disable=protected-access
+  # See b/119257215 for more details. DS test run on GPU could have larger
+  # variance then test on CPU.
+  if (test_util.is_gpu_available() and
       key.startswith(('weights_1', 'weights_2', 'predict_result'))):
     return relaxed_tolerance
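The tolerance helper now keys off GPU availability instead of the strategy type, so the `_DefaultDistributionStrategy` isinstance check (and the strategy imports it required) can be dropped. A rough standalone sketch of the resulting selection logic, with illustrative tolerance values (the real test module defines its own `relaxed_tolerance` and default):

from tensorflow.python.framework import test_util

# Illustrative values; the actual test module defines its own tolerances.
default_tolerance = 1e-5
relaxed_tolerance = 1e-4


def get_compare_result_tolerance(key):
  """Relaxes the tolerance for weight/predict comparisons on GPU machines."""
  if (test_util.is_gpu_available() and
      key.startswith(('weights_1', 'weights_2', 'predict_result'))):
    return relaxed_tolerance
  return default_tolerance


print(get_compare_result_tolerance('predict_result_1'))  # relaxed if a GPU is present
print(get_compare_result_tolerance('training_loss'))     # always the default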