Run mixed precision tests in more cases.
PiperOrigin-RevId: 296053767 Change-Id: I4bcc64b9f09046b23cab0fd76e017f581242bfee
This commit is contained in:
		
							parent
							
								
									6edc8c2a9a
								
							
						
					
					
						commit
						2e667319a5
					
				@ -421,13 +421,11 @@ class KerasLayerTest(keras_parameterized.TestCase):
 | 
			
		||||
class KerasModelTest(keras_parameterized.TestCase):
 | 
			
		||||
  """Test mixed precision with Keras models."""
 | 
			
		||||
 | 
			
		||||
  def _skip_if_strategy_unsupported(self, strategy_fn, check_model_type=False):
 | 
			
		||||
  def _skip_if_strategy_unsupported(self, strategy_fn):
 | 
			
		||||
    if (strategy_fn != default_strategy_fn and
 | 
			
		||||
        (testing_utils.should_run_eagerly() or
 | 
			
		||||
         (check_model_type and testing_utils.get_model_type() == 'subclass'))):
 | 
			
		||||
        testing_utils.get_model_type() == 'subclass'):
 | 
			
		||||
      self.skipTest('Non-default strategies are unsupported with subclassed '
 | 
			
		||||
                    'models or with passing run_eagerly=True to '
 | 
			
		||||
                    'Model.compile()')
 | 
			
		||||
                    'models')
 | 
			
		||||
 | 
			
		||||
  def _skip_if_save_format_unsupported(self, save_format):
 | 
			
		||||
    model_type = testing_utils.get_model_type()
 | 
			
		||||
@ -435,8 +433,8 @@ class KerasModelTest(keras_parameterized.TestCase):
 | 
			
		||||
      self.skipTest('Saving subclassed models with the HDF5 format is '
 | 
			
		||||
                    'unsupported')
 | 
			
		||||
    if (save_format == 'tf' and model_type == 'subclass' and
 | 
			
		||||
        not testing_utils.should_run_tf_function()):
 | 
			
		||||
      self.skipTest('b/142352416: This combination of features is currently '
 | 
			
		||||
        not context.executing_eagerly()):
 | 
			
		||||
      self.skipTest('b/148820505: This combination of features is currently '
 | 
			
		||||
                    'broken.')
 | 
			
		||||
 | 
			
		||||
  @keras_parameterized.run_with_all_model_types
 | 
			
		||||
@ -494,11 +492,10 @@ class KerasModelTest(keras_parameterized.TestCase):
 | 
			
		||||
          'save_format': 'h5',
 | 
			
		||||
          'use_regularizer': True,
 | 
			
		||||
      }, {
 | 
			
		||||
          # TODO(b/148874820): Test saving a model with CentralStorageStrategy.
 | 
			
		||||
          # Currently this doesn't work even for float32.
 | 
			
		||||
          'testcase_name': 'central_storage',
 | 
			
		||||
          'strategy_fn': create_central_storage_strategy,
 | 
			
		||||
          'use_regularizer': True,
 | 
			
		||||
          'save_format': 'tf'
 | 
			
		||||
      }, {
 | 
			
		||||
          'testcase_name': 'norun_distributed',
 | 
			
		||||
          'strategy_fn': create_mirrored_strategy,
 | 
			
		||||
@ -513,7 +510,7 @@ class KerasModelTest(keras_parameterized.TestCase):
 | 
			
		||||
                 save_format=None,
 | 
			
		||||
                 use_input_spec=False,
 | 
			
		||||
                 experimental_run_tf_function=True):
 | 
			
		||||
    self._skip_if_strategy_unsupported(strategy_fn, check_model_type=True)
 | 
			
		||||
    self._skip_if_strategy_unsupported(strategy_fn)
 | 
			
		||||
    self._skip_if_save_format_unsupported(save_format)
 | 
			
		||||
    regularizer = (mp_test_util.IdentityRegularizer() if use_regularizer
 | 
			
		||||
                   else None)
 | 
			
		||||
@ -620,7 +617,6 @@ class KerasModelTest(keras_parameterized.TestCase):
 | 
			
		||||
                              strategy_fn,
 | 
			
		||||
                              experimental_run_tf_function=True):
 | 
			
		||||
    # Note: We do not test mixed precision in this method, only loss scaling.
 | 
			
		||||
    self._skip_if_strategy_unsupported(strategy_fn)
 | 
			
		||||
    loss_scale = 8.
 | 
			
		||||
    batch_size = 4
 | 
			
		||||
    with strategy_fn().scope():
 | 
			
		||||
@ -679,7 +675,6 @@ class KerasModelTest(keras_parameterized.TestCase):
 | 
			
		||||
    #  * Regularization on some variables and not others.
 | 
			
		||||
    #  * A fixed loss scale (if use_loss_scaling is True)
 | 
			
		||||
 | 
			
		||||
    self._skip_if_strategy_unsupported(strategy_fn)
 | 
			
		||||
    strategy = strategy_fn()
 | 
			
		||||
    if use_loss_scaling:
 | 
			
		||||
      loss_scale = 8.
 | 
			
		||||
@ -779,7 +774,6 @@ class KerasModelTest(keras_parameterized.TestCase):
 | 
			
		||||
                                pass_loss_scale_to_policy=False,
 | 
			
		||||
                                get_config=False,
 | 
			
		||||
                                experimental_run_tf_function=True):
 | 
			
		||||
    self._skip_if_strategy_unsupported(strategy_fn)
 | 
			
		||||
    strategy = strategy_fn()
 | 
			
		||||
    initial_loss_scale = 2.
 | 
			
		||||
    batch_size = 4
 | 
			
		||||
@ -956,7 +950,6 @@ class KerasModelTest(keras_parameterized.TestCase):
 | 
			
		||||
  def test_save_slot_variables_with_autocast_vars(self,
 | 
			
		||||
                                                  strategy_fn,
 | 
			
		||||
                                                  var_name='v'):
 | 
			
		||||
    self._skip_if_strategy_unsupported(strategy_fn)
 | 
			
		||||
    p = policy.Policy('mixed_float16', loss_scale=None)
 | 
			
		||||
    with strategy_fn().scope(), policy.policy_scope(p):
 | 
			
		||||
      x = layers.Input(shape=(2,), batch_size=2)
 | 
			
		||||
@ -992,7 +985,6 @@ class KerasModelTest(keras_parameterized.TestCase):
 | 
			
		||||
  @keras_parameterized.run_all_keras_modes
 | 
			
		||||
  @parameterized.named_parameters(*TESTCASES)
 | 
			
		||||
  def test_save_weights_with_dynamic_loss_scaling(self, strategy_fn):
 | 
			
		||||
    self._skip_if_strategy_unsupported(strategy_fn)
 | 
			
		||||
    strategy = strategy_fn()
 | 
			
		||||
    if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and
 | 
			
		||||
        not context.executing_eagerly()):
 | 
			
		||||
@ -1051,7 +1043,6 @@ class KerasModelTest(keras_parameterized.TestCase):
 | 
			
		||||
          'h5': True,
 | 
			
		||||
      })
 | 
			
		||||
  def test_save_model_with_dynamic_loss_scaling(self, strategy_fn, h5=False):
 | 
			
		||||
    self._skip_if_strategy_unsupported(strategy_fn)
 | 
			
		||||
    # TODO(reedwm): Support and test saving model with a mixed_[b]float16 policy
 | 
			
		||||
    # as well.
 | 
			
		||||
    strategy = strategy_fn()
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user