From 1873cf2efee08f04d6a5045b401642bf6e5a3a30 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Tue, 30 Jun 2020 13:54:48 -0700 Subject: [PATCH] Loose the check for regularizer and convert None to default value. PiperOrigin-RevId: 319092652 Change-Id: If75a7d3474290f510d92fd6ebcfaf371f5fe6fcb --- tensorflow/python/keras/regularizers.py | 13 ++++++++++++- tensorflow/python/keras/regularizers_test.py | 12 ++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/keras/regularizers.py b/tensorflow/python/keras/regularizers.py index 919d30220b3..53791df78ac 100644 --- a/tensorflow/python/keras/regularizers.py +++ b/tensorflow/python/keras/regularizers.py @@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import keras_export def _check_penalty_number(x): - """check penalty number availability, raise ValueError if failed""" + """check penalty number availability, raise ValueError if failed.""" if not isinstance(x, (float, int)): raise ValueError(('Value: {} is not a valid regularization penalty number, ' 'expected an int or float value').format(x)) @@ -43,6 +43,10 @@ def _check_penalty_number(x): ).format(x)) +def _none_to_default(inputs, default): + return default if inputs is None else default + + @keras_export('keras.regularizers.Regularizer') class Regularizer(object): """Regularizer base class. @@ -230,6 +234,11 @@ class L1L2(Regularizer): """ def __init__(self, l1=0., l2=0.): # pylint: disable=redefined-outer-name + # The default value for l1 and l2 are different from the value in l1_l2 + # for backward compatiblity reason. Eg, L1L2(l2=0.1) will only have l2 + # and no l1 penalty. + l1 = 0. if l1 is None else l1 + l2 = 0. if l2 is None else l2 _check_penalty_number(l1) _check_penalty_number(l2) @@ -270,6 +279,7 @@ class L1(Regularizer): if kwargs: raise TypeError('Argument(s) not recognized: %s' % (kwargs,)) + l1 = 0.01 if l1 is None else l1 _check_penalty_number(l1) self.l1 = backend.cast_to_floatx(l1) @@ -303,6 +313,7 @@ class L2(Regularizer): if kwargs: raise TypeError('Argument(s) not recognized: %s' % (kwargs,)) + l2 = 0.01 if l2 is None else l2 _check_penalty_number(l2) self.l2 = backend.cast_to_floatx(l2) diff --git a/tensorflow/python/keras/regularizers_test.py b/tensorflow/python/keras/regularizers_test.py index b10218ba114..21cecd03e6a 100644 --- a/tensorflow/python/keras/regularizers_test.py +++ b/tensorflow/python/keras/regularizers_test.py @@ -203,6 +203,18 @@ class KerasRegularizersTest(keras_parameterized.TestCase, with self.assertRaisesRegex(ValueError, 'Could not interpret regularizer'): keras.regularizers.get(0) + @parameterized.named_parameters([ + ('l1', regularizers.l1(l1=None), 0.01), + ('l2', regularizers.l2(l2=None), 0.01), + ('l1_l2', regularizers.l1_l2(l1=None, l2=None), 0.), + ]) + def test_default_value_when_init_with_none(self, regularizer, expected_value): + expected_value = np.asarray(expected_value) + if hasattr(regularizer, 'l1'): + self.assertAllClose(regularizer.l1, expected_value) + if hasattr(regularizer, 'l2'): + self.assertAllClose(regularizer.l2, expected_value) + if __name__ == '__main__': test.main()