From 1feb592358c3ab4dfe89ea7848993c8201b860c8 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Mon, 30 Nov 2020 19:18:05 -0800 Subject: [PATCH] Remove the usage of _DefaultDistributionStrategy in keras. PiperOrigin-RevId: 344934166 Change-Id: I7787e3ab2500796ec3bbe3e0f4a307f559be77b2 --- tensorflow/python/keras/callbacks.py | 6 ++---- .../distribute/keras_correctness_test_base.py | 20 ++++--------------- 2 files changed, 6 insertions(+), 20 deletions(-) diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 493e7383752..f395179e5f1 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -35,7 +35,6 @@ import six from tensorflow.core.framework import summary_pb2 from tensorflow.python.data.ops import iterator_ops from tensorflow.python.distribute import collective_all_reduce_strategy -from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.distribute import tpu_strategy from tensorflow.python.eager import context @@ -1603,7 +1602,6 @@ class BackupAndRestore(Callback): self.backup_dir = backup_dir self._supports_tf_logs = True self._supported_strategies = ( - distribute_lib._DefaultDistributionStrategy, mirrored_strategy.MirroredStrategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy, tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV2) @@ -1631,8 +1629,8 @@ class BackupAndRestore(Callback): # failure-recovery of a worker in training. # pylint: disable=protected-access - if not isinstance(self.model.distribute_strategy, - self._supported_strategies): + if self.model._distribution_strategy and not isinstance( + self.model.distribute_strategy, self._supported_strategies): raise NotImplementedError( '%s is not supported yet. ' 'Currently BackupAndRestore callback only supports empty strategy, ' diff --git a/tensorflow/python/keras/distribute/keras_correctness_test_base.py b/tensorflow/python/keras/distribute/keras_correctness_test_base.py index 1b28564ff73..37a63a5774b 100644 --- a/tensorflow/python/keras/distribute/keras_correctness_test_base.py +++ b/tensorflow/python/keras/distribute/keras_correctness_test_base.py @@ -24,15 +24,12 @@ import numpy as np import six from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.distribute import central_storage_strategy -from tensorflow.python.distribute import distribute_lib -from tensorflow.python.distribute import mirrored_strategy -from tensorflow.python.distribute import one_device_strategy from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import tpu_strategy from tensorflow.python.eager import context from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_combinations as combinations +from tensorflow.python.framework import test_util from tensorflow.python.keras.distribute import distributed_training_utils from tensorflow.python.keras.distribute.strategy_combinations import all_strategies from tensorflow.python.keras.distribute.strategy_combinations import multi_worker_mirrored_strategies @@ -324,18 +321,9 @@ def compare_results(results_with_ds, def _get_compare_result_tolerance(key): """Returns tolerance to compare results.""" - # TODO(b/119257215): For MirroredStrategy, weights are not exactly the same, - # so use larger tolerance for now. Predict should be related to weights. - # Also for CentralStorageStrategy and OneDeviceStrategy which is observed in - # b/172956754. - if (isinstance(distribution, - (mirrored_strategy.MirroredStrategy, - mirrored_strategy.MirroredStrategyV1, - central_storage_strategy.CentralStorageStrategy, - central_storage_strategy.CentralStorageStrategyV1, - one_device_strategy.OneDeviceStrategy, - one_device_strategy.OneDeviceStrategyV1, - distribute_lib._DefaultDistributionStrategy)) and # pylint: disable=protected-access + # See b/119257215 for more details. DS test run on GPU could have larger + # variance then test on CPU. + if (test_util.is_gpu_available() and key.startswith(('weights_1', 'weights_2', 'predict_result'))): return relaxed_tolerance