Run mixed precision tests in more cases.

PiperOrigin-RevId: 296053767
Change-Id: I4bcc64b9f09046b23cab0fd76e017f581242bfee
This commit is contained in:
Reed Wanderman-Milne 2020-02-19 14:09:46 -08:00 committed by TensorFlower Gardener
parent 6edc8c2a9a
commit 2e667319a5

View File

@ -421,13 +421,11 @@ class KerasLayerTest(keras_parameterized.TestCase):
class KerasModelTest(keras_parameterized.TestCase): class KerasModelTest(keras_parameterized.TestCase):
"""Test mixed precision with Keras models.""" """Test mixed precision with Keras models."""
def _skip_if_strategy_unsupported(self, strategy_fn, check_model_type=False): def _skip_if_strategy_unsupported(self, strategy_fn):
if (strategy_fn != default_strategy_fn and if (strategy_fn != default_strategy_fn and
(testing_utils.should_run_eagerly() or testing_utils.get_model_type() == 'subclass'):
(check_model_type and testing_utils.get_model_type() == 'subclass'))):
self.skipTest('Non-default strategies are unsupported with subclassed ' self.skipTest('Non-default strategies are unsupported with subclassed '
'models or with passing run_eagerly=True to ' 'models')
'Model.compile()')
def _skip_if_save_format_unsupported(self, save_format): def _skip_if_save_format_unsupported(self, save_format):
model_type = testing_utils.get_model_type() model_type = testing_utils.get_model_type()
@ -435,8 +433,8 @@ class KerasModelTest(keras_parameterized.TestCase):
self.skipTest('Saving subclassed models with the HDF5 format is ' self.skipTest('Saving subclassed models with the HDF5 format is '
'unsupported') 'unsupported')
if (save_format == 'tf' and model_type == 'subclass' and if (save_format == 'tf' and model_type == 'subclass' and
not testing_utils.should_run_tf_function()): not context.executing_eagerly()):
self.skipTest('b/142352416: This combination of features is currently ' self.skipTest('b/148820505: This combination of features is currently '
'broken.') 'broken.')
@keras_parameterized.run_with_all_model_types @keras_parameterized.run_with_all_model_types
@ -494,11 +492,10 @@ class KerasModelTest(keras_parameterized.TestCase):
'save_format': 'h5', 'save_format': 'h5',
'use_regularizer': True, 'use_regularizer': True,
}, { }, {
# TODO(b/148874820): Test saving a model with CentralStorageStrategy.
# Currently this doesn't work even for float32.
'testcase_name': 'central_storage', 'testcase_name': 'central_storage',
'strategy_fn': create_central_storage_strategy, 'strategy_fn': create_central_storage_strategy,
'use_regularizer': True, 'use_regularizer': True,
'save_format': 'tf'
}, { }, {
'testcase_name': 'norun_distributed', 'testcase_name': 'norun_distributed',
'strategy_fn': create_mirrored_strategy, 'strategy_fn': create_mirrored_strategy,
@ -513,7 +510,7 @@ class KerasModelTest(keras_parameterized.TestCase):
save_format=None, save_format=None,
use_input_spec=False, use_input_spec=False,
experimental_run_tf_function=True): experimental_run_tf_function=True):
self._skip_if_strategy_unsupported(strategy_fn, check_model_type=True) self._skip_if_strategy_unsupported(strategy_fn)
self._skip_if_save_format_unsupported(save_format) self._skip_if_save_format_unsupported(save_format)
regularizer = (mp_test_util.IdentityRegularizer() if use_regularizer regularizer = (mp_test_util.IdentityRegularizer() if use_regularizer
else None) else None)
@ -620,7 +617,6 @@ class KerasModelTest(keras_parameterized.TestCase):
strategy_fn, strategy_fn,
experimental_run_tf_function=True): experimental_run_tf_function=True):
# Note: We do not test mixed precision in this method, only loss scaling. # Note: We do not test mixed precision in this method, only loss scaling.
self._skip_if_strategy_unsupported(strategy_fn)
loss_scale = 8. loss_scale = 8.
batch_size = 4 batch_size = 4
with strategy_fn().scope(): with strategy_fn().scope():
@ -679,7 +675,6 @@ class KerasModelTest(keras_parameterized.TestCase):
# * Regularization on some variables and not others. # * Regularization on some variables and not others.
# * A fixed loss scale (if use_loss_scaling is True) # * A fixed loss scale (if use_loss_scaling is True)
self._skip_if_strategy_unsupported(strategy_fn)
strategy = strategy_fn() strategy = strategy_fn()
if use_loss_scaling: if use_loss_scaling:
loss_scale = 8. loss_scale = 8.
@ -779,7 +774,6 @@ class KerasModelTest(keras_parameterized.TestCase):
pass_loss_scale_to_policy=False, pass_loss_scale_to_policy=False,
get_config=False, get_config=False,
experimental_run_tf_function=True): experimental_run_tf_function=True):
self._skip_if_strategy_unsupported(strategy_fn)
strategy = strategy_fn() strategy = strategy_fn()
initial_loss_scale = 2. initial_loss_scale = 2.
batch_size = 4 batch_size = 4
@ -956,7 +950,6 @@ class KerasModelTest(keras_parameterized.TestCase):
def test_save_slot_variables_with_autocast_vars(self, def test_save_slot_variables_with_autocast_vars(self,
strategy_fn, strategy_fn,
var_name='v'): var_name='v'):
self._skip_if_strategy_unsupported(strategy_fn)
p = policy.Policy('mixed_float16', loss_scale=None) p = policy.Policy('mixed_float16', loss_scale=None)
with strategy_fn().scope(), policy.policy_scope(p): with strategy_fn().scope(), policy.policy_scope(p):
x = layers.Input(shape=(2,), batch_size=2) x = layers.Input(shape=(2,), batch_size=2)
@ -992,7 +985,6 @@ class KerasModelTest(keras_parameterized.TestCase):
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
@parameterized.named_parameters(*TESTCASES) @parameterized.named_parameters(*TESTCASES)
def test_save_weights_with_dynamic_loss_scaling(self, strategy_fn): def test_save_weights_with_dynamic_loss_scaling(self, strategy_fn):
self._skip_if_strategy_unsupported(strategy_fn)
strategy = strategy_fn() strategy = strategy_fn()
if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and
not context.executing_eagerly()): not context.executing_eagerly()):
@ -1051,7 +1043,6 @@ class KerasModelTest(keras_parameterized.TestCase):
'h5': True, 'h5': True,
}) })
def test_save_model_with_dynamic_loss_scaling(self, strategy_fn, h5=False): def test_save_model_with_dynamic_loss_scaling(self, strategy_fn, h5=False):
self._skip_if_strategy_unsupported(strategy_fn)
# TODO(reedwm): Support and test saving model with a mixed_[b]float16 policy # TODO(reedwm): Support and test saving model with a mixed_[b]float16 policy
# as well. # as well.
strategy = strategy_fn() strategy = strategy_fn()