Remove the usage of _DefaultDistributionStrategy in keras.
PiperOrigin-RevId: 344934166 Change-Id: I7787e3ab2500796ec3bbe3e0f4a307f559be77b2
This commit is contained in:
parent
4e17145925
commit
1feb592358
@ -35,7 +35,6 @@ import six
|
||||
from tensorflow.core.framework import summary_pb2
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
from tensorflow.python.distribute import mirrored_strategy
|
||||
from tensorflow.python.distribute import tpu_strategy
|
||||
from tensorflow.python.eager import context
|
||||
@ -1603,7 +1602,6 @@ class BackupAndRestore(Callback):
|
||||
self.backup_dir = backup_dir
|
||||
self._supports_tf_logs = True
|
||||
self._supported_strategies = (
|
||||
distribute_lib._DefaultDistributionStrategy,
|
||||
mirrored_strategy.MirroredStrategy,
|
||||
collective_all_reduce_strategy.CollectiveAllReduceStrategy,
|
||||
tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV2)
|
||||
@ -1631,8 +1629,8 @@ class BackupAndRestore(Callback):
|
||||
# failure-recovery of a worker in training.
|
||||
# pylint: disable=protected-access
|
||||
|
||||
if not isinstance(self.model.distribute_strategy,
|
||||
self._supported_strategies):
|
||||
if self.model._distribution_strategy and not isinstance(
|
||||
self.model.distribute_strategy, self._supported_strategies):
|
||||
raise NotImplementedError(
|
||||
'%s is not supported yet. '
|
||||
'Currently BackupAndRestore callback only supports empty strategy, '
|
||||
|
@ -24,15 +24,12 @@ import numpy as np
|
||||
import six
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import central_storage_strategy
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
from tensorflow.python.distribute import mirrored_strategy
|
||||
from tensorflow.python.distribute import one_device_strategy
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.distribute import tpu_strategy
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.framework import test_combinations as combinations
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.distribute import distributed_training_utils
|
||||
from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
|
||||
from tensorflow.python.keras.distribute.strategy_combinations import multi_worker_mirrored_strategies
|
||||
@ -324,18 +321,9 @@ def compare_results(results_with_ds,
|
||||
|
||||
def _get_compare_result_tolerance(key):
|
||||
"""Returns tolerance to compare results."""
|
||||
# TODO(b/119257215): For MirroredStrategy, weights are not exactly the same,
|
||||
# so use larger tolerance for now. Predict should be related to weights.
|
||||
# Also for CentralStorageStrategy and OneDeviceStrategy which is observed in
|
||||
# b/172956754.
|
||||
if (isinstance(distribution,
|
||||
(mirrored_strategy.MirroredStrategy,
|
||||
mirrored_strategy.MirroredStrategyV1,
|
||||
central_storage_strategy.CentralStorageStrategy,
|
||||
central_storage_strategy.CentralStorageStrategyV1,
|
||||
one_device_strategy.OneDeviceStrategy,
|
||||
one_device_strategy.OneDeviceStrategyV1,
|
||||
distribute_lib._DefaultDistributionStrategy)) and # pylint: disable=protected-access
|
||||
# See b/119257215 for more details. DS test run on GPU could have larger
|
||||
# variance then test on CPU.
|
||||
if (test_util.is_gpu_available() and
|
||||
key.startswith(('weights_1', 'weights_2', 'predict_result'))):
|
||||
return relaxed_tolerance
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user