Remove the usage of _DefaultDistributionStrategy in keras.
PiperOrigin-RevId: 344934166 Change-Id: I7787e3ab2500796ec3bbe3e0f4a307f559be77b2
This commit is contained in:
parent
4e17145925
commit
1feb592358
@ -35,7 +35,6 @@ import six
|
|||||||
from tensorflow.core.framework import summary_pb2
|
from tensorflow.core.framework import summary_pb2
|
||||||
from tensorflow.python.data.ops import iterator_ops
|
from tensorflow.python.data.ops import iterator_ops
|
||||||
from tensorflow.python.distribute import collective_all_reduce_strategy
|
from tensorflow.python.distribute import collective_all_reduce_strategy
|
||||||
from tensorflow.python.distribute import distribute_lib
|
|
||||||
from tensorflow.python.distribute import mirrored_strategy
|
from tensorflow.python.distribute import mirrored_strategy
|
||||||
from tensorflow.python.distribute import tpu_strategy
|
from tensorflow.python.distribute import tpu_strategy
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
@ -1603,7 +1602,6 @@ class BackupAndRestore(Callback):
|
|||||||
self.backup_dir = backup_dir
|
self.backup_dir = backup_dir
|
||||||
self._supports_tf_logs = True
|
self._supports_tf_logs = True
|
||||||
self._supported_strategies = (
|
self._supported_strategies = (
|
||||||
distribute_lib._DefaultDistributionStrategy,
|
|
||||||
mirrored_strategy.MirroredStrategy,
|
mirrored_strategy.MirroredStrategy,
|
||||||
collective_all_reduce_strategy.CollectiveAllReduceStrategy,
|
collective_all_reduce_strategy.CollectiveAllReduceStrategy,
|
||||||
tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV2)
|
tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV2)
|
||||||
@ -1631,8 +1629,8 @@ class BackupAndRestore(Callback):
|
|||||||
# failure-recovery of a worker in training.
|
# failure-recovery of a worker in training.
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
|
|
||||||
if not isinstance(self.model.distribute_strategy,
|
if self.model._distribution_strategy and not isinstance(
|
||||||
self._supported_strategies):
|
self.model.distribute_strategy, self._supported_strategies):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
'%s is not supported yet. '
|
'%s is not supported yet. '
|
||||||
'Currently BackupAndRestore callback only supports empty strategy, '
|
'Currently BackupAndRestore callback only supports empty strategy, '
|
||||||
|
@ -24,15 +24,12 @@ import numpy as np
|
|||||||
import six
|
import six
|
||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.distribute import central_storage_strategy
|
|
||||||
from tensorflow.python.distribute import distribute_lib
|
|
||||||
from tensorflow.python.distribute import mirrored_strategy
|
|
||||||
from tensorflow.python.distribute import one_device_strategy
|
|
||||||
from tensorflow.python.distribute import strategy_combinations
|
from tensorflow.python.distribute import strategy_combinations
|
||||||
from tensorflow.python.distribute import tpu_strategy
|
from tensorflow.python.distribute import tpu_strategy
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import random_seed
|
from tensorflow.python.framework import random_seed
|
||||||
from tensorflow.python.framework import test_combinations as combinations
|
from tensorflow.python.framework import test_combinations as combinations
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.keras.distribute import distributed_training_utils
|
from tensorflow.python.keras.distribute import distributed_training_utils
|
||||||
from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
|
from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
|
||||||
from tensorflow.python.keras.distribute.strategy_combinations import multi_worker_mirrored_strategies
|
from tensorflow.python.keras.distribute.strategy_combinations import multi_worker_mirrored_strategies
|
||||||
@ -324,18 +321,9 @@ def compare_results(results_with_ds,
|
|||||||
|
|
||||||
def _get_compare_result_tolerance(key):
|
def _get_compare_result_tolerance(key):
|
||||||
"""Returns tolerance to compare results."""
|
"""Returns tolerance to compare results."""
|
||||||
# TODO(b/119257215): For MirroredStrategy, weights are not exactly the same,
|
# See b/119257215 for more details. DS test run on GPU could have larger
|
||||||
# so use larger tolerance for now. Predict should be related to weights.
|
# variance then test on CPU.
|
||||||
# Also for CentralStorageStrategy and OneDeviceStrategy which is observed in
|
if (test_util.is_gpu_available() and
|
||||||
# b/172956754.
|
|
||||||
if (isinstance(distribution,
|
|
||||||
(mirrored_strategy.MirroredStrategy,
|
|
||||||
mirrored_strategy.MirroredStrategyV1,
|
|
||||||
central_storage_strategy.CentralStorageStrategy,
|
|
||||||
central_storage_strategy.CentralStorageStrategyV1,
|
|
||||||
one_device_strategy.OneDeviceStrategy,
|
|
||||||
one_device_strategy.OneDeviceStrategyV1,
|
|
||||||
distribute_lib._DefaultDistributionStrategy)) and # pylint: disable=protected-access
|
|
||||||
key.startswith(('weights_1', 'weights_2', 'predict_result'))):
|
key.startswith(('weights_1', 'weights_2', 'predict_result'))):
|
||||||
return relaxed_tolerance
|
return relaxed_tolerance
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user