Merge pull request #22478 from yongtang:22470-tf.keras.regularizers

PiperOrigin-RevId: 221948116
This commit is contained in:
TensorFlower Gardener 2018-11-17 18:27:06 -08:00
commit 209c7dae93
2 changed files with 19 additions and 6 deletions

View File

@ -26,6 +26,7 @@ from tensorflow.python.keras import testing_utils
from tensorflow.python.layers import core as tf_core_layers from tensorflow.python.layers import core as tf_core_layers
from tensorflow.python.ops import nn from tensorflow.python.ops import nn
from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -312,6 +313,15 @@ class KerasIntegrationTest(test.TestCase):
verbose=0) verbose=0)
self.assertGreater(history.history['val_acc'][-1], 0.7) self.assertGreater(history.history['val_acc'][-1], 0.7)
def test_regularizers_with_get_variable(self):
# Test case for GitHub issue 22470.
with self.cached_session():
v = variable_scope.get_variable(
'v',
shape=[4, 4],
initializer=keras.initializers.glorot_uniform(),
regularizer=keras.regularizers.l2(0.))
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import six import six
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
@ -54,12 +55,14 @@ class L1L2(Regularizer):
self.l2 = K.cast_to_floatx(l2) self.l2 = K.cast_to_floatx(l2)
def __call__(self, x): def __call__(self, x):
regularization = 0. if self.l1 or self.l2:
if self.l1: regularization = ops.convert_to_tensor(0., dtype=K.floatx())
regularization += math_ops.reduce_sum(self.l1 * math_ops.abs(x)) if self.l1:
if self.l2: regularization += math_ops.reduce_sum(self.l1 * math_ops.abs(x))
regularization += math_ops.reduce_sum(self.l2 * math_ops.square(x)) if self.l2:
return regularization regularization += math_ops.reduce_sum(self.l2 * math_ops.square(x))
return regularization
return None
def get_config(self): def get_config(self):
return {'l1': float(self.l1), 'l2': float(self.l2)} return {'l1': float(self.l1), 'l2': float(self.l2)}