Add support of scale_l1 == 0. or scale_l2 == 0 to l1_l2_regularizer.

Added tests.

PiperOrigin-RevId: 158179790
This commit is contained in:
Sergio Guadarrama 2017-06-06 13:14:02 -07:00 committed by TensorFlower Gardener
parent b65eb3f9b5
commit 96cb4d1820
2 changed files with 32 additions and 0 deletions

View File

@ -124,7 +124,15 @@ def l1_l2_regularizer(scale_l1=1.0, scale_l2=1.0, scope=None):
Raises:
ValueError: If scale is negative or if scale is not a float.
"""
if isinstance(scale_l1, numbers.Integral):
raise ValueError('scale_l1 cannot be an integer: %s' % (scale_l1,))
if isinstance(scale_l2, numbers.Integral):
raise ValueError('scale_l2 cannot be an integer: %s' % (scale_l2,))
scope = scope or 'l1_l2_regularizer'
if scale_l1 == 0.:
return l2_regularizer(scale_l2, scope)
if scale_l2 == 0.:
return l1_regularizer(scale_l1, scope)
return sum_regularizer([l1_regularizer(scale_l1),
l2_regularizer(scale_l2)],
scope=scope)

View File

@ -79,6 +79,30 @@ class RegularizerTest(test.TestCase):
self.assertEquals(loss.op.name, 'l1_l2_regularizer')
self.assertAlmostEqual(loss.eval(), num_elem + num_elem / 2, 5)
def test_l1_l2_scale_l1Zero(self):
shape = [5, 5, 5]
num_elem = 5 * 5 * 5
tensor = constant_op.constant(1.0, shape=shape)
loss = regularizers.l1_l2_regularizer(0.0, 1.0)(tensor)
with self.test_session():
self.assertEquals(loss.op.name, 'l1_l2_regularizer')
self.assertAlmostEqual(loss.eval(), num_elem / 2, 5)
def test_l1_l2_scale_l2Zero(self):
shape = [5, 5, 5]
num_elem = 5 * 5 * 5
tensor = constant_op.constant(1.0, shape=shape)
loss = regularizers.l1_l2_regularizer(1.0, 0.0)(tensor)
with self.test_session():
self.assertEquals(loss.op.name, 'l1_l2_regularizer')
self.assertAlmostEqual(loss.eval(), num_elem, 5)
def test_l1_l2_scales_Zero(self):
shape = [5, 5, 5]
tensor = constant_op.constant(1.0, shape=shape)
loss = regularizers.l1_l2_regularizer(0.0, 0.0)(tensor)
self.assertEquals(loss, None)
def testL1L2RegularizerWithScope(self):
with self.test_session():
shape = [5, 5, 5]