Add support of scale_l1 == 0. or scale_l2 == 0 to l1_l2_regularizer.
Added tests. PiperOrigin-RevId: 158179790
This commit is contained in:
parent
b65eb3f9b5
commit
96cb4d1820
@ -124,7 +124,15 @@ def l1_l2_regularizer(scale_l1=1.0, scale_l2=1.0, scope=None):
|
||||
Raises:
|
||||
ValueError: If scale is negative or if scale is not a float.
|
||||
"""
|
||||
if isinstance(scale_l1, numbers.Integral):
|
||||
raise ValueError('scale_l1 cannot be an integer: %s' % (scale_l1,))
|
||||
if isinstance(scale_l2, numbers.Integral):
|
||||
raise ValueError('scale_l2 cannot be an integer: %s' % (scale_l2,))
|
||||
scope = scope or 'l1_l2_regularizer'
|
||||
if scale_l1 == 0.:
|
||||
return l2_regularizer(scale_l2, scope)
|
||||
if scale_l2 == 0.:
|
||||
return l1_regularizer(scale_l1, scope)
|
||||
return sum_regularizer([l1_regularizer(scale_l1),
|
||||
l2_regularizer(scale_l2)],
|
||||
scope=scope)
|
||||
|
||||
@ -79,6 +79,30 @@ class RegularizerTest(test.TestCase):
|
||||
self.assertEquals(loss.op.name, 'l1_l2_regularizer')
|
||||
self.assertAlmostEqual(loss.eval(), num_elem + num_elem / 2, 5)
|
||||
|
||||
def test_l1_l2_scale_l1Zero(self):
|
||||
shape = [5, 5, 5]
|
||||
num_elem = 5 * 5 * 5
|
||||
tensor = constant_op.constant(1.0, shape=shape)
|
||||
loss = regularizers.l1_l2_regularizer(0.0, 1.0)(tensor)
|
||||
with self.test_session():
|
||||
self.assertEquals(loss.op.name, 'l1_l2_regularizer')
|
||||
self.assertAlmostEqual(loss.eval(), num_elem / 2, 5)
|
||||
|
||||
def test_l1_l2_scale_l2Zero(self):
|
||||
shape = [5, 5, 5]
|
||||
num_elem = 5 * 5 * 5
|
||||
tensor = constant_op.constant(1.0, shape=shape)
|
||||
loss = regularizers.l1_l2_regularizer(1.0, 0.0)(tensor)
|
||||
with self.test_session():
|
||||
self.assertEquals(loss.op.name, 'l1_l2_regularizer')
|
||||
self.assertAlmostEqual(loss.eval(), num_elem, 5)
|
||||
|
||||
def test_l1_l2_scales_Zero(self):
|
||||
shape = [5, 5, 5]
|
||||
tensor = constant_op.constant(1.0, shape=shape)
|
||||
loss = regularizers.l1_l2_regularizer(0.0, 0.0)(tensor)
|
||||
self.assertEquals(loss, None)
|
||||
|
||||
def testL1L2RegularizerWithScope(self):
|
||||
with self.test_session():
|
||||
shape = [5, 5, 5]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user