Remove the usage of _DefaultDistributionStrategy in keras.
PiperOrigin-RevId: 344934166 Change-Id: I7787e3ab2500796ec3bbe3e0f4a307f559be77b2
This commit is contained in:
		
							parent
							
								
									4e17145925
								
							
						
					
					
						commit
						1feb592358
					
				@ -35,7 +35,6 @@ import six
 | 
			
		||||
from tensorflow.core.framework import summary_pb2
 | 
			
		||||
from tensorflow.python.data.ops import iterator_ops
 | 
			
		||||
from tensorflow.python.distribute import collective_all_reduce_strategy
 | 
			
		||||
from tensorflow.python.distribute import distribute_lib
 | 
			
		||||
from tensorflow.python.distribute import mirrored_strategy
 | 
			
		||||
from tensorflow.python.distribute import tpu_strategy
 | 
			
		||||
from tensorflow.python.eager import context
 | 
			
		||||
@ -1603,7 +1602,6 @@ class BackupAndRestore(Callback):
 | 
			
		||||
    self.backup_dir = backup_dir
 | 
			
		||||
    self._supports_tf_logs = True
 | 
			
		||||
    self._supported_strategies = (
 | 
			
		||||
        distribute_lib._DefaultDistributionStrategy,
 | 
			
		||||
        mirrored_strategy.MirroredStrategy,
 | 
			
		||||
        collective_all_reduce_strategy.CollectiveAllReduceStrategy,
 | 
			
		||||
        tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV2)
 | 
			
		||||
@ -1631,8 +1629,8 @@ class BackupAndRestore(Callback):
 | 
			
		||||
    # failure-recovery of a worker in training.
 | 
			
		||||
    # pylint: disable=protected-access
 | 
			
		||||
 | 
			
		||||
    if not isinstance(self.model.distribute_strategy,
 | 
			
		||||
                      self._supported_strategies):
 | 
			
		||||
    if self.model._distribution_strategy and not isinstance(
 | 
			
		||||
        self.model.distribute_strategy, self._supported_strategies):
 | 
			
		||||
      raise NotImplementedError(
 | 
			
		||||
          '%s is not supported yet. '
 | 
			
		||||
          'Currently BackupAndRestore callback only supports empty strategy, '
 | 
			
		||||
 | 
			
		||||
@ -24,15 +24,12 @@ import numpy as np
 | 
			
		||||
import six
 | 
			
		||||
from tensorflow.python import keras
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.distribute import central_storage_strategy
 | 
			
		||||
from tensorflow.python.distribute import distribute_lib
 | 
			
		||||
from tensorflow.python.distribute import mirrored_strategy
 | 
			
		||||
from tensorflow.python.distribute import one_device_strategy
 | 
			
		||||
from tensorflow.python.distribute import strategy_combinations
 | 
			
		||||
from tensorflow.python.distribute import tpu_strategy
 | 
			
		||||
from tensorflow.python.eager import context
 | 
			
		||||
from tensorflow.python.framework import random_seed
 | 
			
		||||
from tensorflow.python.framework import test_combinations as combinations
 | 
			
		||||
from tensorflow.python.framework import test_util
 | 
			
		||||
from tensorflow.python.keras.distribute import distributed_training_utils
 | 
			
		||||
from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
 | 
			
		||||
from tensorflow.python.keras.distribute.strategy_combinations import multi_worker_mirrored_strategies
 | 
			
		||||
@ -324,18 +321,9 @@ def compare_results(results_with_ds,
 | 
			
		||||
 | 
			
		||||
  def _get_compare_result_tolerance(key):
 | 
			
		||||
    """Returns tolerance to compare results."""
 | 
			
		||||
    # TODO(b/119257215): For MirroredStrategy, weights are not exactly the same,
 | 
			
		||||
    # so use larger tolerance for now. Predict should be related to weights.
 | 
			
		||||
    # Also for CentralStorageStrategy and OneDeviceStrategy which is observed in
 | 
			
		||||
    # b/172956754.
 | 
			
		||||
    if (isinstance(distribution,
 | 
			
		||||
                   (mirrored_strategy.MirroredStrategy,
 | 
			
		||||
                    mirrored_strategy.MirroredStrategyV1,
 | 
			
		||||
                    central_storage_strategy.CentralStorageStrategy,
 | 
			
		||||
                    central_storage_strategy.CentralStorageStrategyV1,
 | 
			
		||||
                    one_device_strategy.OneDeviceStrategy,
 | 
			
		||||
                    one_device_strategy.OneDeviceStrategyV1,
 | 
			
		||||
                    distribute_lib._DefaultDistributionStrategy)) and  # pylint: disable=protected-access
 | 
			
		||||
    # See b/119257215 for more details. DS test run on GPU could have larger
 | 
			
		||||
    # variance then test on CPU.
 | 
			
		||||
    if (test_util.is_gpu_available() and
 | 
			
		||||
        key.startswith(('weights_1', 'weights_2', 'predict_result'))):
 | 
			
		||||
      return relaxed_tolerance
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user