Run mixed precision tests in more cases.

PiperOrigin-RevId: 296053767
Change-Id: I4bcc64b9f09046b23cab0fd76e017f581242bfee
commit 2e667319a5
parent 6edc8c2a9a
Author: Reed Wanderman-Milne, 2020-02-19 14:09:46 -08:00
Committed by: TensorFlower Gardener
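
For context: the tests changed below exercise Keras models under a mixed precision dtype policy. A minimal sketch of that setup, using the TF 2.1-era experimental API this file targets (the model and shapes are illustrative, not taken from the tests):

    import tensorflow as tf

    # Under 'mixed_float16', layers compute in float16 but keep their
    # variables in float32 (TF 2.1-era experimental API).
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

    x = tf.keras.layers.Input(shape=(2,), batch_size=2)
    y = tf.keras.layers.Dense(2)(x)
    model = tf.keras.Model(x, y)

    print(y.dtype.name)                       # 'float16': activations in half
    print(model.layers[1].kernel.dtype.name)  # 'float32': variables in full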

tensorflow/python/keras/mixed_precision/experimental/keras_test.py

@@ -421,13 +421,11 @@ class KerasLayerTest(keras_parameterized.TestCase):
 class KerasModelTest(keras_parameterized.TestCase):
   """Test mixed precision with Keras models."""
 
-  def _skip_if_strategy_unsupported(self, strategy_fn, check_model_type=False):
+  def _skip_if_strategy_unsupported(self, strategy_fn):
     if (strategy_fn != default_strategy_fn and
-        (testing_utils.should_run_eagerly() or
-         (check_model_type and testing_utils.get_model_type() == 'subclass'))):
+        testing_utils.get_model_type() == 'subclass'):
       self.skipTest('Non-default strategies are unsupported with subclassed '
-                    'models or with passing run_eagerly=True to '
-                    'Model.compile()')
+                    'models')
 
   def _skip_if_save_format_unsupported(self, save_format):
     model_type = testing_utils.get_model_type()
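
The hunk above widens coverage: with check_model_type gone, non-default strategies are skipped only for subclassed models, so they now also run with run_eagerly=True. The strategy_fn values being compared follow roughly the pattern below; these are hypothetical stand-ins for the file's real helpers:

    import tensorflow as tf

    def default_strategy_fn():
      # The "no distribution" default strategy.
      return tf.distribute.get_strategy()

    def create_mirrored_strategy():
      # Hypothetical stand-in for the test helper of the same name.
      return tf.distribute.MirroredStrategy(devices=['/cpu:0'])

    # Each test builds its model inside the chosen strategy's scope.
    with create_mirrored_strategy().scope():
      model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(2,))])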
@@ -435,8 +433,8 @@ class KerasModelTest(keras_parameterized.TestCase):
       self.skipTest('Saving subclassed models with the HDF5 format is '
                     'unsupported')
     if (save_format == 'tf' and model_type == 'subclass' and
-        not testing_utils.should_run_tf_function()):
-      self.skipTest('b/142352416: This combination of features is currently '
+        not context.executing_eagerly()):
+      self.skipTest('b/148820505: This combination of features is currently '
                     'broken.')
 
   @keras_parameterized.run_with_all_model_types
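
This hunk keys the skip off eager execution rather than the internal should_run_tf_function() toggle, and updates the tracking bug. The guarded combination is exporting a subclassed model to the SavedModel ('tf') format; a rough illustration, with a made-up model and path:

    import tensorflow as tf

    class TinyModel(tf.keras.Model):

      def __init__(self):
        super(TinyModel, self).__init__()
        self.dense = tf.keras.layers.Dense(2)

      def call(self, inputs):
        return self.dense(inputs)

    model = TinyModel()
    model(tf.zeros([2, 2]))  # Build the variables before saving.
    # 'h5' is unsupported for subclassed models (the first skip above);
    # 'tf' (SavedModel) is the format whose graph-mode path is still broken.
    model.save('/tmp/tiny_model', save_format='tf')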
@@ -494,11 +492,10 @@ class KerasModelTest(keras_parameterized.TestCase):
       'save_format': 'h5',
       'use_regularizer': True,
   }, {
-      # TODO(b/148874820): Test saving a model with CentralStorageStrategy.
-      # Currently this doesn't work even for float32.
       'testcase_name': 'central_storage',
       'strategy_fn': create_central_storage_strategy,
       'use_regularizer': True,
+      'save_format': 'tf'
   }, {
       'testcase_name': 'norun_distributed',
       'strategy_fn': create_mirrored_strategy,
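
With TODO(b/148874820) dropped and 'save_format': 'tf' added, the central-storage case now covers saving as well. Roughly what that exercises, assuming create_central_storage_strategy wraps tf.distribute.experimental.CentralStorageStrategy:

    import tensorflow as tf

    # Hypothetical stand-in for create_central_storage_strategy.
    strategy = tf.distribute.experimental.CentralStorageStrategy()

    with strategy.scope():
      model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(2,))])
    model.compile('sgd', 'mse')

    # The new test case saves in the SavedModel format under this strategy.
    model.save('/tmp/central_storage_model', save_format='tf')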
@@ -513,7 +510,7 @@ class KerasModelTest(keras_parameterized.TestCase):
                    save_format=None,
                    use_input_spec=False,
                    experimental_run_tf_function=True):
-    self._skip_if_strategy_unsupported(strategy_fn, check_model_type=True)
+    self._skip_if_strategy_unsupported(strategy_fn)
     self._skip_if_save_format_unsupported(save_format)
     regularizer = (mp_test_util.IdentityRegularizer() if use_regularizer
                    else None)
@@ -620,7 +617,6 @@ class KerasModelTest(keras_parameterized.TestCase):
                              strategy_fn,
                              experimental_run_tf_function=True):
     # Note: We do not test mixed precision in this method, only loss scaling.
-    self._skip_if_strategy_unsupported(strategy_fn)
     loss_scale = 8.
     batch_size = 4
     with strategy_fn().scope():
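
Per the note, this method tests loss scaling alone, with a fixed scale of 8. A sketch of what a fixed loss scale does, using the era's experimental LossScaleOptimizer (the variable and loss are illustrative):

    import tensorflow as tf

    opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
        tf.keras.optimizers.SGD(1.0), loss_scale=8.)

    var = tf.Variable(1.)
    with tf.GradientTape() as tape:
      loss = var * var
      scaled_loss = opt.get_scaled_loss(loss)        # loss * 8
    scaled_grads = tape.gradient(scaled_loss, [var])
    grads = opt.get_unscaled_gradients(scaled_grads)  # grads / 8
    assert grads[0].numpy() == 2.0  # d(v^2)/dv at v=1; the scale round-trips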
@@ -679,7 +675,6 @@ class KerasModelTest(keras_parameterized.TestCase):
     #  * Regularization on some variables and not others.
     #  * A fixed loss scale (if use_loss_scaling is True)
 
-    self._skip_if_strategy_unsupported(strategy_fn)
     strategy = strategy_fn()
     if use_loss_scaling:
       loss_scale = 8.
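
"Regularization on some variables and not others" looks roughly like the following; the built-in l2 regularizer stands in here for the test's own mp_test_util.IdentityRegularizer:

    import tensorflow as tf

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(4, input_shape=(2,),
                              kernel_regularizer=tf.keras.regularizers.l2(1e-4)),
        tf.keras.layers.Dense(2),  # no regularizer on this layer
    ])
    # One regularization loss, from the first layer's float32 kernel.
    print(model.losses)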
@@ -779,7 +774,6 @@ class KerasModelTest(keras_parameterized.TestCase):
                               pass_loss_scale_to_policy=False,
                               get_config=False,
                               experimental_run_tf_function=True):
-    self._skip_if_strategy_unsupported(strategy_fn)
     strategy = strategy_fn()
     initial_loss_scale = 2.
     batch_size = 4
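
The initial_loss_scale = 2. above feeds a dynamic loss scale, which doubles after a run of finite-gradient steps and halves on overflow. A sketch with the TF 2.1-era API (the pass_loss_scale_to_policy flag covers attaching the same object to the Policy instead):

    import tensorflow as tf

    loss_scale = tf.train.experimental.DynamicLossScale(
        initial_loss_scale=2., increment_period=2000, multiplier=2.)
    opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
        tf.keras.optimizers.SGD(1.0), loss_scale)
    print(opt.loss_scale())  # current scale: 2.0, doubling every 2000 good steps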
@@ -956,7 +950,6 @@ class KerasModelTest(keras_parameterized.TestCase):
   def test_save_slot_variables_with_autocast_vars(self,
                                                   strategy_fn,
                                                   var_name='v'):
-    self._skip_if_strategy_unsupported(strategy_fn)
     p = policy.Policy('mixed_float16', loss_scale=None)
     with strategy_fn().scope(), policy.policy_scope(p):
       x = layers.Input(shape=(2,), batch_size=2)
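
Slot variables are the per-weight state an optimizer keeps (momentum, Adam moments, and so on); this test checks that they save and restore when the model's weights are autocast variables. A standalone illustration of a slot variable:

    import tensorflow as tf

    opt = tf.keras.optimizers.SGD(1.0, momentum=0.5)
    var = tf.Variable([1.0], name='v')
    opt.apply_gradients([(tf.constant([1.0]), var)])
    # 'momentum' is a slot variable, created per trained variable.
    print(opt.get_slot(var, 'momentum'))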
@@ -992,7 +985,6 @@ class KerasModelTest(keras_parameterized.TestCase):
   @keras_parameterized.run_all_keras_modes
   @parameterized.named_parameters(*TESTCASES)
   def test_save_weights_with_dynamic_loss_scaling(self, strategy_fn):
-    self._skip_if_strategy_unsupported(strategy_fn)
     strategy = strategy_fn()
     if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and
         not context.executing_eagerly()):
@@ -1051,7 +1043,6 @@ class KerasModelTest(keras_parameterized.TestCase):
       'h5': True,
   })
   def test_save_model_with_dynamic_loss_scaling(self, strategy_fn, h5=False):
-    self._skip_if_strategy_unsupported(strategy_fn)
     # TODO(reedwm): Support and test saving model with a mixed_[b]float16 policy
     # as well.
     strategy = strategy_fn()
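
Whole-model saving here has to round-trip the dynamic loss scale's state (its current scale and good-step counter) in both formats that the h5 parameter toggles between. A sketch of the two save paths, with illustrative file paths:

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(2,))])
    model.compile('sgd', 'mse')

    model.save('/tmp/model.h5', save_format='h5')  # HDF5
    model.save('/tmp/model_tf', save_format='tf')  # SavedModel
    restored = tf.keras.models.load_model('/tmp/model_tf')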