Run mixed precision tests in more cases.
PiperOrigin-RevId: 296053767
Change-Id: I4bcc64b9f09046b23cab0fd76e017f581242bfee
commit 2e667319a5
parent 6edc8c2a9a
@@ -421,13 +421,11 @@ class KerasLayerTest(keras_parameterized.TestCase):
 class KerasModelTest(keras_parameterized.TestCase):
   """Test mixed precision with Keras models."""
 
-  def _skip_if_strategy_unsupported(self, strategy_fn, check_model_type=False):
+  def _skip_if_strategy_unsupported(self, strategy_fn):
     if (strategy_fn != default_strategy_fn and
-        (testing_utils.should_run_eagerly() or
-         (check_model_type and testing_utils.get_model_type() == 'subclass'))):
+        testing_utils.get_model_type() == 'subclass'):
       self.skipTest('Non-default strategies are unsupported with subclassed '
-                    'models or with passing run_eagerly=True to '
-                    'Model.compile()')
+                    'models')
 
   def _skip_if_save_format_unsupported(self, save_format):
     model_type = testing_utils.get_model_type()
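
With the relaxed helper above, non-default strategies are skipped only for subclassed models, not when run_eagerly is enabled. A minimal sketch of the kind of combination these tests now cover (illustrative, not part of this change; it assumes the TF 2.1-era tf.keras.mixed_precision.experimental API): a mixed_float16 model compiled with run_eagerly=True inside a MirroredStrategy scope.

import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as mixed_precision

mixed_precision.set_policy(mixed_precision.Policy('mixed_float16'))
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(4, input_shape=(2,)),
      # Keep the last layer in float32 so the loss is computed in float32.
      tf.keras.layers.Dense(1, dtype='float32'),
  ])
  # run_eagerly=True with a non-default strategy was previously skipped.
  model.compile(optimizer='sgd', loss='mse', run_eagerly=True)
model.fit(tf.ones((8, 2)), tf.ones((8, 1)), batch_size=4, epochs=1, verbose=0)
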
@@ -435,8 +433,8 @@ class KerasModelTest(keras_parameterized.TestCase):
       self.skipTest('Saving subclassed models with the HDF5 format is '
                     'unsupported')
     if (save_format == 'tf' and model_type == 'subclass' and
-        not testing_utils.should_run_tf_function()):
-      self.skipTest('b/142352416: This combination of features is currently '
+        not context.executing_eagerly()):
+      self.skipTest('b/148820505: This combination of features is currently '
                     'broken.')
 
   @keras_parameterized.run_with_all_model_types
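
For reference, the "HDF5 is unsupported for subclassed models" skip above corresponds to Keras behavior along these lines (an illustrative sketch, not code from this change; the exact exception type may vary by TF version):

import tensorflow as tf

class TinyModel(tf.keras.Model):

  def __init__(self):
    super(TinyModel, self).__init__()
    self.dense = tf.keras.layers.Dense(1)

  def call(self, inputs):
    return self.dense(inputs)

model = TinyModel()
model(tf.ones((1, 2)))  # Build the model by calling it once.
try:
  # Subclassed models have no static graph of layers to serialize, so Keras
  # rejects saving them in the HDF5 format.
  model.save('tiny_model.h5', save_format='h5')
except (NotImplementedError, ValueError) as e:
  print('HDF5 saving of subclassed models is unsupported:', e)
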
@@ -494,11 +492,10 @@ class KerasModelTest(keras_parameterized.TestCase):
           'save_format': 'h5',
           'use_regularizer': True,
       }, {
-          # TODO(b/148874820): Test saving a model with CentralStorageStrategy.
-          # Currently this doesn't work even for float32.
           'testcase_name': 'central_storage',
           'strategy_fn': create_central_storage_strategy,
           'use_regularizer': True,
+          'save_format': 'tf'
       }, {
           'testcase_name': 'norun_distributed',
           'strategy_fn': create_mirrored_strategy,
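
The central_storage case now also exercises saving in the TF SavedModel format. A rough sketch of the scenario it covers (illustrative only, not the test code):

import os
import tempfile

import tensorflow as tf

strategy = tf.distribute.experimental.CentralStorageStrategy()
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])
  model.compile(optimizer='sgd', loss='mse')
model.fit(tf.ones((4, 2)), tf.ones((4, 1)), epochs=1, verbose=0)

# Save and reload in the TF SavedModel format while the strategy's variables
# are live.
save_dir = os.path.join(tempfile.mkdtemp(), 'saved_model')
model.save(save_dir, save_format='tf')
restored = tf.keras.models.load_model(save_dir)
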
@@ -513,7 +510,7 @@ class KerasModelTest(keras_parameterized.TestCase):
                  save_format=None,
                  use_input_spec=False,
                  experimental_run_tf_function=True):
-    self._skip_if_strategy_unsupported(strategy_fn, check_model_type=True)
+    self._skip_if_strategy_unsupported(strategy_fn)
     self._skip_if_save_format_unsupported(save_format)
     regularizer = (mp_test_util.IdentityRegularizer() if use_regularizer
                    else None)
@@ -620,7 +617,6 @@ class KerasModelTest(keras_parameterized.TestCase):
                             strategy_fn,
                             experimental_run_tf_function=True):
     # Note: We do not test mixed precision in this method, only loss scaling.
-    self._skip_if_strategy_unsupported(strategy_fn)
     loss_scale = 8.
     batch_size = 4
     with strategy_fn().scope():
@@ -679,7 +675,6 @@ class KerasModelTest(keras_parameterized.TestCase):
     # * Regularization on some variables and not others.
     # * A fixed loss scale (if use_loss_scaling is True)
 
-    self._skip_if_strategy_unsupported(strategy_fn)
     strategy = strategy_fn()
     if use_loss_scaling:
       loss_scale = 8.
@@ -779,7 +774,6 @@ class KerasModelTest(keras_parameterized.TestCase):
                             pass_loss_scale_to_policy=False,
                             get_config=False,
                             experimental_run_tf_function=True):
-    self._skip_if_strategy_unsupported(strategy_fn)
     strategy = strategy_fn()
     initial_loss_scale = 2.
     batch_size = 4
@@ -956,7 +950,6 @@ class KerasModelTest(keras_parameterized.TestCase):
   def test_save_slot_variables_with_autocast_vars(self,
                                                   strategy_fn,
                                                   var_name='v'):
-    self._skip_if_strategy_unsupported(strategy_fn)
     p = policy.Policy('mixed_float16', loss_scale=None)
     with strategy_fn().scope(), policy.policy_scope(p):
       x = layers.Input(shape=(2,), batch_size=2)
@@ -992,7 +985,6 @@ class KerasModelTest(keras_parameterized.TestCase):
   @keras_parameterized.run_all_keras_modes
   @parameterized.named_parameters(*TESTCASES)
   def test_save_weights_with_dynamic_loss_scaling(self, strategy_fn):
-    self._skip_if_strategy_unsupported(strategy_fn)
     strategy = strategy_fn()
     if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and
         not context.executing_eagerly()):
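
As a reminder of what this test covers, here is a rough sketch (illustrative, assuming the TF 2.1-era experimental LossScaleOptimizer API, not the test code itself) of saving and restoring weights when a dynamic loss scale is part of the optimizer state:

import os
import tempfile

import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as mixed_precision

opt = mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(), loss_scale='dynamic')
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])
model.compile(optimizer=opt, loss='mse')
model.fit(tf.ones((4, 2)), tf.ones((4, 1)), verbose=0)

# The dynamic loss scale is tracked through the optimizer, so it should
# round-trip through save_weights / load_weights in the TF checkpoint format.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), 'ckpt')
model.save_weights(ckpt_prefix)
model.load_weights(ckpt_prefix)
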
@@ -1051,7 +1043,6 @@ class KerasModelTest(keras_parameterized.TestCase):
           'h5': True,
       })
   def test_save_model_with_dynamic_loss_scaling(self, strategy_fn, h5=False):
-    self._skip_if_strategy_unsupported(strategy_fn)
     # TODO(reedwm): Support and test saving model with a mixed_[b]float16 policy
     # as well.
     strategy = strategy_fn()