Support saving mixed precision models with activity regularizers.
Fixes https://github.com/tensorflow/tensorflow/issues/43978
PiperOrigin-RevId: 337973257
Change-Id: Iefb46a31ff6a45d4e5d84fb7cb8822e4ae1cd039
This commit is contained in:
parent
a7467d5d51
commit
ab9b5f5b05
@ -549,15 +549,19 @@ class KerasModelTest(keras_parameterized.TestCase):
|
||||
use_v1_policy=False):
|
||||
self._skip_if_strategy_unsupported(strategy_fn)
|
||||
self._skip_if_save_format_unsupported(save_format)
|
||||
regularizer = (mp_test_util.IdentityRegularizer() if use_regularizer
|
||||
else None)
|
||||
if use_regularizer:
|
||||
weight_regularizer = mp_test_util.IdentityRegularizer()
|
||||
activity_regularizer = mp_test_util.ReduceSumRegularizer()
|
||||
else:
|
||||
weight_regularizer = activity_regularizer = None
|
||||
with strategy_fn().scope():
|
||||
cls = policy.PolicyV1 if use_v1_policy else policy.Policy
|
||||
with policy.policy_scope(cls(policy_name)):
|
||||
layer = mp_test_util.MultiplyLayer(
|
||||
assert_type=dtypes.float16,
|
||||
use_operator=use_operator,
|
||||
regularizer=regularizer,
|
||||
regularizer=weight_regularizer,
|
||||
activity_regularizer=activity_regularizer,
|
||||
input_shape=(1,))
|
||||
if use_input_spec:
|
||||
layer.input_spec = input_spec.InputSpec(shape=(None, 1))
|
||||
@ -596,8 +600,9 @@ class KerasModelTest(keras_parameterized.TestCase):
|
||||
# from it.
|
||||
expected = 1 - 2**-14
|
||||
if use_regularizer:
|
||||
# Regularizer adds another 2 ** -14 to the gradient.
|
||||
expected -= 2**-14
|
||||
# Weight and activity regularizer each add another 2 ** -14 to the
|
||||
# gradient.
|
||||
expected -= 2 * 2**-14
|
||||
self.assertEqual(backend.eval(layer.v), expected)
|
||||
|
||||
if save_format:
|
||||
@ -614,14 +619,14 @@ class KerasModelTest(keras_parameterized.TestCase):
|
||||
if 'MultiplyLayer' in layer.__class__.__name__)
|
||||
expected = 1 - 2**-14
|
||||
if use_regularizer:
|
||||
expected -= 2**-14
|
||||
expected -= 2 * 2**-14
|
||||
self.assertEqual(backend.eval(layer.v), expected)
|
||||
|
||||
# Continue training, and assert variable is correct value
|
||||
model.fit(dataset)
|
||||
new_expected = expected - 2 ** -14
|
||||
if use_regularizer:
|
||||
new_expected -= 2 ** -14
|
||||
new_expected -= 2 * 2 ** -14
|
||||
self.assertEqual(backend.eval(layer.v), new_expected)
|
||||
|
||||
# Load saved model again, and assert variable is previous value
|
||||
|
@ -130,13 +130,15 @@ class MultiplyLayer(AssertTypeLayer):
|
||||
|
||||
def __init__(self,
|
||||
regularizer=None,
|
||||
activity_regularizer=None,
|
||||
use_operator=False,
|
||||
var_name='v',
|
||||
**kwargs):
|
||||
"""Initializes the MultiplyLayer.
|
||||
|
||||
Args:
|
||||
regularizer: The regularizer on the scalar variable.
|
||||
regularizer: The weight regularizer on the scalar variable.
|
||||
activity_regularizer: The activity regularizer.
|
||||
use_operator: If True, multiply using the * operator. If False, multiply using
|
||||
tf.multiply.
|
||||
var_name: The name of the variable. It can be useful to pass a name other
|
||||
@ -148,9 +150,15 @@ class MultiplyLayer(AssertTypeLayer):
|
||||
if isinstance(regularizer, dict):
|
||||
self._regularizer = regularizers.deserialize(regularizer,
|
||||
custom_objects=globals())
|
||||
self._activity_regularizer = activity_regularizer
|
||||
if isinstance(activity_regularizer, dict):
|
||||
self._activity_regularizer = regularizers.deserialize(
|
||||
activity_regularizer, custom_objects=globals())
|
||||
|
||||
self._use_operator = use_operator
|
||||
self._var_name = var_name
|
||||
super(MultiplyLayer, self).__init__(**kwargs)
|
||||
super(MultiplyLayer, self).__init__(
|
||||
activity_regularizer=self._activity_regularizer, **kwargs)
|
||||
|
||||
def build(self, _):
|
||||
self.v = self.add_weight(
|
||||
@ -170,6 +178,8 @@ class MultiplyLayer(AssertTypeLayer):
|
||||
def get_config(self):
|
||||
config = super(MultiplyLayer, self).get_config()
|
||||
config['regularizer'] = regularizers.serialize(self._regularizer)
|
||||
config['activity_regularizer'] = regularizers.serialize(
|
||||
self._activity_regularizer)
|
||||
config['use_operator'] = self._use_operator
|
||||
config['var_name'] = self._var_name
|
||||
config['assert_type'] = self._assert_type
|
||||
@ -184,3 +194,12 @@ class IdentityRegularizer(regularizers.Regularizer):
|
||||
|
||||
def get_config(self):
|
||||
return {}
|
||||
|
||||
|
||||
class ReduceSumRegularizer(regularizers.Regularizer):
  """Regularizer whose penalty is `math_ops.reduce_sum` of its input.

  Holds no configuration of its own, so `get_config` returns an empty dict.
  """

  def __call__(self, x):
    """Returns the result of `math_ops.reduce_sum(x)` as the penalty."""
    penalty = math_ops.reduce_sum(x)
    return penalty

  def get_config(self):
    """Returns an empty config dict; there is no state to serialize."""
    return dict()
|
||||
|
@ -628,7 +628,9 @@ def _wrap_activity_regularizer(layer):
|
||||
return def_function.Function(
|
||||
layer._activity_regularizer,
|
||||
'{}_activity_regularizer'.format(layer.name),
|
||||
input_signature=[tensor_spec.TensorSpec(None, layer.dtype or K.floatx())])
|
||||
input_signature=[
|
||||
tensor_spec.TensorSpec(None, layer._compute_dtype or K.floatx())
|
||||
])
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user