Loose the check for regularizer and convert None to default value.

PiperOrigin-RevId: 319092652
Change-Id: If75a7d3474290f510d92fd6ebcfaf371f5fe6fcb
This commit is contained in:
Scott Zhu 2020-06-30 13:54:48 -07:00 committed by TensorFlower Gardener
parent 0d37c8f114
commit 1873cf2efe
2 changed files with 24 additions and 1 deletions

View File

@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import keras_export
def _check_penalty_number(x):
"""check penalty number availability, raise ValueError if failed"""
"""check penalty number availability, raise ValueError if failed."""
if not isinstance(x, (float, int)):
raise ValueError(('Value: {} is not a valid regularization penalty number, '
'expected an int or float value').format(x))
@ -43,6 +43,10 @@ def _check_penalty_number(x):
).format(x))
def _none_to_default(inputs, default):
return default if inputs is None else default
@keras_export('keras.regularizers.Regularizer')
class Regularizer(object):
"""Regularizer base class.
@ -230,6 +234,11 @@ class L1L2(Regularizer):
"""
def __init__(self, l1=0., l2=0.): # pylint: disable=redefined-outer-name
# The default value for l1 and l2 are different from the value in l1_l2
# for backward compatiblity reason. Eg, L1L2(l2=0.1) will only have l2
# and no l1 penalty.
l1 = 0. if l1 is None else l1
l2 = 0. if l2 is None else l2
_check_penalty_number(l1)
_check_penalty_number(l2)
@ -270,6 +279,7 @@ class L1(Regularizer):
if kwargs:
raise TypeError('Argument(s) not recognized: %s' % (kwargs,))
l1 = 0.01 if l1 is None else l1
_check_penalty_number(l1)
self.l1 = backend.cast_to_floatx(l1)
@ -303,6 +313,7 @@ class L2(Regularizer):
if kwargs:
raise TypeError('Argument(s) not recognized: %s' % (kwargs,))
l2 = 0.01 if l2 is None else l2
_check_penalty_number(l2)
self.l2 = backend.cast_to_floatx(l2)

View File

@ -203,6 +203,18 @@ class KerasRegularizersTest(keras_parameterized.TestCase,
with self.assertRaisesRegex(ValueError, 'Could not interpret regularizer'):
keras.regularizers.get(0)
@parameterized.named_parameters([
('l1', regularizers.l1(l1=None), 0.01),
('l2', regularizers.l2(l2=None), 0.01),
('l1_l2', regularizers.l1_l2(l1=None, l2=None), 0.),
])
def test_default_value_when_init_with_none(self, regularizer, expected_value):
expected_value = np.asarray(expected_value)
if hasattr(regularizer, 'l1'):
self.assertAllClose(regularizer.l1, expected_value)
if hasattr(regularizer, 'l2'):
self.assertAllClose(regularizer.l2, expected_value)
if __name__ == '__main__':
test.main()