From 2e667319a5e18c0b1caafb2f7c4f8387a1ab747e Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Wed, 19 Feb 2020 14:09:46 -0800 Subject: [PATCH] Run mixed precision tests in more cases. PiperOrigin-RevId: 296053767 Change-Id: I4bcc64b9f09046b23cab0fd76e017f581242bfee --- .../experimental/keras_test.py | 23 ++++++------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py index f1bf1f2bde2..8ec8d914cf5 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py @@ -421,13 +421,11 @@ class KerasLayerTest(keras_parameterized.TestCase): class KerasModelTest(keras_parameterized.TestCase): """Test mixed precision with Keras models.""" - def _skip_if_strategy_unsupported(self, strategy_fn, check_model_type=False): + def _skip_if_strategy_unsupported(self, strategy_fn): if (strategy_fn != default_strategy_fn and - (testing_utils.should_run_eagerly() or - (check_model_type and testing_utils.get_model_type() == 'subclass'))): + testing_utils.get_model_type() == 'subclass'): self.skipTest('Non-default strategies are unsupported with subclassed ' - 'models or with passing run_eagerly=True to ' - 'Model.compile()') + 'models') def _skip_if_save_format_unsupported(self, save_format): model_type = testing_utils.get_model_type() @@ -435,8 +433,8 @@ class KerasModelTest(keras_parameterized.TestCase): self.skipTest('Saving subclassed models with the HDF5 format is ' 'unsupported') if (save_format == 'tf' and model_type == 'subclass' and - not testing_utils.should_run_tf_function()): - self.skipTest('b/142352416: This combination of features is currently ' + not context.executing_eagerly()): + self.skipTest('b/148820505: This combination of features is currently ' 'broken.') @keras_parameterized.run_with_all_model_types @@ -494,11 +492,10 @@ class KerasModelTest(keras_parameterized.TestCase): 'save_format': 'h5', 'use_regularizer': True, }, { - # TODO(b/148874820): Test saving a model with CentralStorageStrategy. - # Currently this doesn't work even for float32. 'testcase_name': 'central_storage', 'strategy_fn': create_central_storage_strategy, 'use_regularizer': True, + 'save_format': 'tf' }, { 'testcase_name': 'norun_distributed', 'strategy_fn': create_mirrored_strategy, @@ -513,7 +510,7 @@ class KerasModelTest(keras_parameterized.TestCase): save_format=None, use_input_spec=False, experimental_run_tf_function=True): - self._skip_if_strategy_unsupported(strategy_fn, check_model_type=True) + self._skip_if_strategy_unsupported(strategy_fn) self._skip_if_save_format_unsupported(save_format) regularizer = (mp_test_util.IdentityRegularizer() if use_regularizer else None) @@ -620,7 +617,6 @@ class KerasModelTest(keras_parameterized.TestCase): strategy_fn, experimental_run_tf_function=True): # Note: We do not test mixed precision in this method, only loss scaling. - self._skip_if_strategy_unsupported(strategy_fn) loss_scale = 8. batch_size = 4 with strategy_fn().scope(): @@ -679,7 +675,6 @@ class KerasModelTest(keras_parameterized.TestCase): # * Regularization on some variables and not others. # * A fixed loss scale (if use_loss_scaling is True) - self._skip_if_strategy_unsupported(strategy_fn) strategy = strategy_fn() if use_loss_scaling: loss_scale = 8. @@ -779,7 +774,6 @@ class KerasModelTest(keras_parameterized.TestCase): pass_loss_scale_to_policy=False, get_config=False, experimental_run_tf_function=True): - self._skip_if_strategy_unsupported(strategy_fn) strategy = strategy_fn() initial_loss_scale = 2. batch_size = 4 @@ -956,7 +950,6 @@ class KerasModelTest(keras_parameterized.TestCase): def test_save_slot_variables_with_autocast_vars(self, strategy_fn, var_name='v'): - self._skip_if_strategy_unsupported(strategy_fn) p = policy.Policy('mixed_float16', loss_scale=None) with strategy_fn().scope(), policy.policy_scope(p): x = layers.Input(shape=(2,), batch_size=2) @@ -992,7 +985,6 @@ class KerasModelTest(keras_parameterized.TestCase): @keras_parameterized.run_all_keras_modes @parameterized.named_parameters(*TESTCASES) def test_save_weights_with_dynamic_loss_scaling(self, strategy_fn): - self._skip_if_strategy_unsupported(strategy_fn) strategy = strategy_fn() if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and not context.executing_eagerly()): @@ -1051,7 +1043,6 @@ class KerasModelTest(keras_parameterized.TestCase): 'h5': True, }) def test_save_model_with_dynamic_loss_scaling(self, strategy_fn, h5=False): - self._skip_if_strategy_unsupported(strategy_fn) # TODO(reedwm): Support and test saving model with a mixed_[b]float16 policy # as well. strategy = strategy_fn()