Merge pull request #22478 from yongtang:22470-tf.keras.regularizers
PiperOrigin-RevId: 221948116
This commit is contained in:
commit
209c7dae93
@ -26,6 +26,7 @@ from tensorflow.python.keras import testing_utils
|
|||||||
from tensorflow.python.layers import core as tf_core_layers
|
from tensorflow.python.layers import core as tf_core_layers
|
||||||
from tensorflow.python.ops import nn
|
from tensorflow.python.ops import nn
|
||||||
from tensorflow.python.ops import rnn_cell
|
from tensorflow.python.ops import rnn_cell
|
||||||
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -312,6 +313,15 @@ class KerasIntegrationTest(test.TestCase):
|
|||||||
verbose=0)
|
verbose=0)
|
||||||
self.assertGreater(history.history['val_acc'][-1], 0.7)
|
self.assertGreater(history.history['val_acc'][-1], 0.7)
|
||||||
|
|
||||||
|
def test_regularizers_with_get_variable(self):
|
||||||
|
# Test case for GitHub issue 22470.
|
||||||
|
with self.cached_session():
|
||||||
|
v = variable_scope.get_variable(
|
||||||
|
'v',
|
||||||
|
shape=[4, 4],
|
||||||
|
initializer=keras.initializers.glorot_uniform(),
|
||||||
|
regularizer=keras.regularizers.l2(0.))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
|
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
|
||||||
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
|
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
|
||||||
@ -54,12 +55,14 @@ class L1L2(Regularizer):
|
|||||||
self.l2 = K.cast_to_floatx(l2)
|
self.l2 = K.cast_to_floatx(l2)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
regularization = 0.
|
if self.l1 or self.l2:
|
||||||
if self.l1:
|
regularization = ops.convert_to_tensor(0., dtype=K.floatx())
|
||||||
regularization += math_ops.reduce_sum(self.l1 * math_ops.abs(x))
|
if self.l1:
|
||||||
if self.l2:
|
regularization += math_ops.reduce_sum(self.l1 * math_ops.abs(x))
|
||||||
regularization += math_ops.reduce_sum(self.l2 * math_ops.square(x))
|
if self.l2:
|
||||||
return regularization
|
regularization += math_ops.reduce_sum(self.l2 * math_ops.square(x))
|
||||||
|
return regularization
|
||||||
|
return None
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return {'l1': float(self.l1), 'l2': float(self.l2)}
|
return {'l1': float(self.l1), 'l2': float(self.l2)}
|
||||||
|
Loading…
Reference in New Issue
Block a user