Merge pull request #22478 from yongtang:22470-tf.keras.regularizers
PiperOrigin-RevId: 221948116
This commit is contained in:
commit
209c7dae93
@ -26,6 +26,7 @@ from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.layers import core as tf_core_layers
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -312,6 +313,15 @@ class KerasIntegrationTest(test.TestCase):
|
||||
verbose=0)
|
||||
self.assertGreater(history.history['val_acc'][-1], 0.7)
|
||||
|
||||
def test_regularizers_with_get_variable(self):
|
||||
# Test case for GitHub issue 22470.
|
||||
with self.cached_session():
|
||||
v = variable_scope.get_variable(
|
||||
'v',
|
||||
shape=[4, 4],
|
||||
initializer=keras.initializers.glorot_uniform(),
|
||||
regularizer=keras.regularizers.l2(0.))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
|
||||
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
|
||||
@ -54,12 +55,14 @@ class L1L2(Regularizer):
|
||||
self.l2 = K.cast_to_floatx(l2)
|
||||
|
||||
def __call__(self, x):
|
||||
regularization = 0.
|
||||
if self.l1 or self.l2:
|
||||
regularization = ops.convert_to_tensor(0., dtype=K.floatx())
|
||||
if self.l1:
|
||||
regularization += math_ops.reduce_sum(self.l1 * math_ops.abs(x))
|
||||
if self.l2:
|
||||
regularization += math_ops.reduce_sum(self.l2 * math_ops.square(x))
|
||||
return regularization
|
||||
return None
|
||||
|
||||
def get_config(self):
|
||||
return {'l1': float(self.l1), 'l2': float(self.l2)}
|
||||
|
Loading…
Reference in New Issue
Block a user