Support saving mixed precision models with activity regularizers.

Fixes https://github.com/tensorflow/tensorflow/issues/43978

PiperOrigin-RevId: 337973257
Change-Id: Iefb46a31ff6a45d4e5d84fb7cb8822e4ae1cd039
Author: Reed Wanderman-Milne
Date: 2020-10-19 18:05:28 -07:00
Committed by: TensorFlower Gardener
Parent: a7467d5d51
Commit: ab9b5f5b05
3 changed files with 36 additions and 10 deletions
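
For context, the scenario this change enables (per the linked issue) is saving a Keras model that both runs under a mixed precision policy and uses an activity regularizer. A minimal sketch of that scenario, assuming the public TF 2.4+ Keras API (tf.keras.mixed_precision.set_global_policy, tf.keras.regularizers.L2); the model and the save path are illustrative, not part of this change:

import tensorflow as tf

# Compute in float16 while keeping variables in float32.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

model = tf.keras.Sequential([
    tf.keras.layers.Dense(
        8,
        activation='relu',
        # The activity regularizer runs on the layer's float16 outputs.
        activity_regularizer=tf.keras.regularizers.L2(1e-4),
        input_shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='sgd', loss='mse')

# Exporting to the SavedModel format is the case this commit addresses.
model.save('/tmp/mixed_precision_model')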


@@ -549,15 +549,19 @@ class KerasModelTest(keras_parameterized.TestCase):
                  use_v1_policy=False):
     self._skip_if_strategy_unsupported(strategy_fn)
     self._skip_if_save_format_unsupported(save_format)
-    regularizer = (mp_test_util.IdentityRegularizer() if use_regularizer
-                   else None)
+    if use_regularizer:
+      weight_regularizer = mp_test_util.IdentityRegularizer()
+      activity_regularizer = mp_test_util.ReduceSumRegularizer()
+    else:
+      weight_regularizer = activity_regularizer = None
     with strategy_fn().scope():
       cls = policy.PolicyV1 if use_v1_policy else policy.Policy
       with policy.policy_scope(cls(policy_name)):
         layer = mp_test_util.MultiplyLayer(
             assert_type=dtypes.float16,
             use_operator=use_operator,
-            regularizer=regularizer,
+            regularizer=weight_regularizer,
+            activity_regularizer=activity_regularizer,
             input_shape=(1,))
         if use_input_spec:
           layer.input_spec = input_spec.InputSpec(shape=(None, 1))
@@ -596,8 +600,9 @@ class KerasModelTest(keras_parameterized.TestCase):
     # from it.
     expected = 1 - 2**-14
     if use_regularizer:
-      # Regularizer adds another 2 ** -14 to the gradient.
-      expected -= 2**-14
+      # Weight and activity regularizer each add another 2 ** -14 to the
+      # gradient.
+      expected -= 2 * 2**-14
     self.assertEqual(backend.eval(layer.v), expected)

     if save_format:
@@ -614,14 +619,14 @@ class KerasModelTest(keras_parameterized.TestCase):
                   if 'MultiplyLayer' in layer.__class__.__name__)
       expected = 1 - 2**-14
       if use_regularizer:
-        expected -= 2**-14
+        expected -= 2 * 2**-14
       self.assertEqual(backend.eval(layer.v), expected)

       # Continue training, and assert variable is correct value
       model.fit(dataset)
       new_expected = expected - 2 ** -14
       if use_regularizer:
-        new_expected -= 2 ** -14
+        new_expected -= 2 * 2 ** -14
       self.assertEqual(backend.eval(layer.v), new_expected)

       # Load saved model again, and assert variable is previous value
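
The expected values in the assertions above all follow the same bookkeeping. A short sketch of that arithmetic, assuming (as the surrounding test appears to set up) that one train step changes the variable by 2 ** -14 per unit of gradient, and that the main loss, the weight regularizer, and the new activity regularizer each contribute one unit:

step = 2 ** -14
use_regularizer = True

expected = 1 - step        # variable starts at 1; the main loss contributes one unit
if use_regularizer:
  expected -= 2 * step     # weight regularizer + activity regularizer

# After one more model.fit(dataset) on the restored model:
new_expected = expected - step
if use_regularizer:
  new_expected -= 2 * step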


@@ -130,13 +130,15 @@ class MultiplyLayer(AssertTypeLayer):
   def __init__(self,
                regularizer=None,
+               activity_regularizer=None,
                use_operator=False,
                var_name='v',
                **kwargs):
     """Initializes the MultiplyLayer.

     Args:
-      regularizer: The regularizer on the scalar variable.
+      regularizer: The weight regularizer on the scalar variable.
+      activity_regularizer: The activity regularizer.
       use_operator: If True, add using the * operator. If False, add using
         tf.multiply.
       var_name: The name of the variable. It can be useful to pass a name other
@@ -148,9 +150,15 @@ class MultiplyLayer(AssertTypeLayer):
     if isinstance(regularizer, dict):
       self._regularizer = regularizers.deserialize(regularizer,
                                                    custom_objects=globals())
+    self._activity_regularizer = activity_regularizer
+    if isinstance(activity_regularizer, dict):
+      self._activity_regularizer = regularizers.deserialize(
+          activity_regularizer, custom_objects=globals())
     self._use_operator = use_operator
     self._var_name = var_name
-    super(MultiplyLayer, self).__init__(**kwargs)
+    super(MultiplyLayer, self).__init__(
+        activity_regularizer=self._activity_regularizer, **kwargs)

   def build(self, _):
     self.v = self.add_weight(
@@ -170,6 +178,8 @@ class MultiplyLayer(AssertTypeLayer):
   def get_config(self):
     config = super(MultiplyLayer, self).get_config()
     config['regularizer'] = regularizers.serialize(self._regularizer)
+    config['activity_regularizer'] = regularizers.serialize(
+        self._activity_regularizer)
     config['use_operator'] = self._use_operator
     config['var_name'] = self._var_name
     config['assert_type'] = self._assert_type
@@ -184,3 +194,12 @@ class IdentityRegularizer(regularizers.Regularizer):
   def get_config(self):
     return {}
+
+
+class ReduceSumRegularizer(regularizers.Regularizer):
+
+  def __call__(self, x):
+    return math_ops.reduce_sum(x)
+
+  def get_config(self):
+    return {}
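
With the changes above, MultiplyLayer forwards the activity regularizer to the base Layer (so Keras applies it to the layer's outputs) and round-trips it through get_config/from_config just like the weight regularizer. A rough sketch of that round trip, assuming the test utilities are importable as mp_test_util; this is illustrative only, not part of the change:

layer = mp_test_util.MultiplyLayer(
    regularizer=mp_test_util.IdentityRegularizer(),
    activity_regularizer=mp_test_util.ReduceSumRegularizer())

config = layer.get_config()
# Both regularizers are serialized to plain dicts via regularizers.serialize,
# so the config is JSON-friendly and usable when saving a model.
restored = mp_test_util.MultiplyLayer.from_config(config)
assert isinstance(restored._activity_regularizer,
                  mp_test_util.ReduceSumRegularizer)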


@@ -628,7 +628,9 @@ def _wrap_activity_regularizer(layer):
   return def_function.Function(
       layer._activity_regularizer,
       '{}_activity_regularizer'.format(layer.name),
-      input_signature=[tensor_spec.TensorSpec(None, layer.dtype or K.floatx())])
+      input_signature=[
+          tensor_spec.TensorSpec(None, layer._compute_dtype or K.floatx())
+      ])
 # pylint: enable=protected-access
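
The one-line change above is the core of the fix: under a mixed precision policy, a layer's dtype (its variable dtype) is float32, but the layer's outputs, which the activity regularizer actually receives, are in the compute dtype, float16. Building the wrapped regularizer's input signature from layer.dtype therefore traced it with the wrong dtype when saving. A small sketch of that distinction, assuming the public TF 2.4+ API (layer.compute_dtype is the public counterpart of the _compute_dtype used above):

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy('mixed_float16')

layer = tf.keras.layers.Dense(4)
layer.build((None, 4))

print(layer.dtype)          # float32 -- the variable dtype
print(layer.compute_dtype)  # float16 -- the dtype of the layer's outputs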