Run mixed precision tests in more cases.
PiperOrigin-RevId: 296053767 Change-Id: I4bcc64b9f09046b23cab0fd76e017f581242bfee
This commit is contained in:
parent
6edc8c2a9a
commit
2e667319a5
@ -421,13 +421,11 @@ class KerasLayerTest(keras_parameterized.TestCase):
|
||||
class KerasModelTest(keras_parameterized.TestCase):
|
||||
"""Test mixed precision with Keras models."""
|
||||
|
||||
def _skip_if_strategy_unsupported(self, strategy_fn, check_model_type=False):
|
||||
def _skip_if_strategy_unsupported(self, strategy_fn):
|
||||
if (strategy_fn != default_strategy_fn and
|
||||
(testing_utils.should_run_eagerly() or
|
||||
(check_model_type and testing_utils.get_model_type() == 'subclass'))):
|
||||
testing_utils.get_model_type() == 'subclass'):
|
||||
self.skipTest('Non-default strategies are unsupported with subclassed '
|
||||
'models or with passing run_eagerly=True to '
|
||||
'Model.compile()')
|
||||
'models')
|
||||
|
||||
def _skip_if_save_format_unsupported(self, save_format):
|
||||
model_type = testing_utils.get_model_type()
|
||||
@ -435,8 +433,8 @@ class KerasModelTest(keras_parameterized.TestCase):
|
||||
self.skipTest('Saving subclassed models with the HDF5 format is '
|
||||
'unsupported')
|
||||
if (save_format == 'tf' and model_type == 'subclass' and
|
||||
not testing_utils.should_run_tf_function()):
|
||||
self.skipTest('b/142352416: This combination of features is currently '
|
||||
not context.executing_eagerly()):
|
||||
self.skipTest('b/148820505: This combination of features is currently '
|
||||
'broken.')
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
@ -494,11 +492,10 @@ class KerasModelTest(keras_parameterized.TestCase):
|
||||
'save_format': 'h5',
|
||||
'use_regularizer': True,
|
||||
}, {
|
||||
# TODO(b/148874820): Test saving a model with CentralStorageStrategy.
|
||||
# Currently this doesn't work even for float32.
|
||||
'testcase_name': 'central_storage',
|
||||
'strategy_fn': create_central_storage_strategy,
|
||||
'use_regularizer': True,
|
||||
'save_format': 'tf'
|
||||
}, {
|
||||
'testcase_name': 'norun_distributed',
|
||||
'strategy_fn': create_mirrored_strategy,
|
||||
@ -513,7 +510,7 @@ class KerasModelTest(keras_parameterized.TestCase):
|
||||
save_format=None,
|
||||
use_input_spec=False,
|
||||
experimental_run_tf_function=True):
|
||||
self._skip_if_strategy_unsupported(strategy_fn, check_model_type=True)
|
||||
self._skip_if_strategy_unsupported(strategy_fn)
|
||||
self._skip_if_save_format_unsupported(save_format)
|
||||
regularizer = (mp_test_util.IdentityRegularizer() if use_regularizer
|
||||
else None)
|
||||
@ -620,7 +617,6 @@ class KerasModelTest(keras_parameterized.TestCase):
|
||||
strategy_fn,
|
||||
experimental_run_tf_function=True):
|
||||
# Note: We do not test mixed precision in this method, only loss scaling.
|
||||
self._skip_if_strategy_unsupported(strategy_fn)
|
||||
loss_scale = 8.
|
||||
batch_size = 4
|
||||
with strategy_fn().scope():
|
||||
@ -679,7 +675,6 @@ class KerasModelTest(keras_parameterized.TestCase):
|
||||
# * Regularization on some variables and not others.
|
||||
# * A fixed loss scale (if use_loss_scaling is True)
|
||||
|
||||
self._skip_if_strategy_unsupported(strategy_fn)
|
||||
strategy = strategy_fn()
|
||||
if use_loss_scaling:
|
||||
loss_scale = 8.
|
||||
@ -779,7 +774,6 @@ class KerasModelTest(keras_parameterized.TestCase):
|
||||
pass_loss_scale_to_policy=False,
|
||||
get_config=False,
|
||||
experimental_run_tf_function=True):
|
||||
self._skip_if_strategy_unsupported(strategy_fn)
|
||||
strategy = strategy_fn()
|
||||
initial_loss_scale = 2.
|
||||
batch_size = 4
|
||||
@ -956,7 +950,6 @@ class KerasModelTest(keras_parameterized.TestCase):
|
||||
def test_save_slot_variables_with_autocast_vars(self,
|
||||
strategy_fn,
|
||||
var_name='v'):
|
||||
self._skip_if_strategy_unsupported(strategy_fn)
|
||||
p = policy.Policy('mixed_float16', loss_scale=None)
|
||||
with strategy_fn().scope(), policy.policy_scope(p):
|
||||
x = layers.Input(shape=(2,), batch_size=2)
|
||||
@ -992,7 +985,6 @@ class KerasModelTest(keras_parameterized.TestCase):
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@parameterized.named_parameters(*TESTCASES)
|
||||
def test_save_weights_with_dynamic_loss_scaling(self, strategy_fn):
|
||||
self._skip_if_strategy_unsupported(strategy_fn)
|
||||
strategy = strategy_fn()
|
||||
if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and
|
||||
not context.executing_eagerly()):
|
||||
@ -1051,7 +1043,6 @@ class KerasModelTest(keras_parameterized.TestCase):
|
||||
'h5': True,
|
||||
})
|
||||
def test_save_model_with_dynamic_loss_scaling(self, strategy_fn, h5=False):
|
||||
self._skip_if_strategy_unsupported(strategy_fn)
|
||||
# TODO(reedwm): Support and test saving model with a mixed_[b]float16 policy
|
||||
# as well.
|
||||
strategy = strategy_fn()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user