diff --git a/tensorflow/python/keras/integration_test.py b/tensorflow/python/keras/integration_test.py index 3c0f73b1c3a..25ca9e69e2e 100644 --- a/tensorflow/python/keras/integration_test.py +++ b/tensorflow/python/keras/integration_test.py @@ -26,6 +26,7 @@ from tensorflow.python.keras import testing_utils from tensorflow.python.layers import core as tf_core_layers from tensorflow.python.ops import nn from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test @@ -312,6 +313,15 @@ class KerasIntegrationTest(test.TestCase): verbose=0) self.assertGreater(history.history['val_acc'][-1], 0.7) + def test_regularizers_with_get_variable(self): + # Test case for GitHub issue 22470. + with self.cached_session(): + v = variable_scope.get_variable( + 'v', + shape=[4, 4], + initializer=keras.initializers.glorot_uniform(), + regularizer=keras.regularizers.l2(0.)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/regularizers.py b/tensorflow/python/keras/regularizers.py index 28b6ad4c65a..cbcdae214f9 100644 --- a/tensorflow/python/keras/regularizers.py +++ b/tensorflow/python/keras/regularizers.py @@ -20,6 +20,7 @@ from __future__ import print_function import six +from tensorflow.python.framework import ops from tensorflow.python.keras import backend as K from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras.utils.generic_utils import serialize_keras_object @@ -54,12 +55,14 @@ class L1L2(Regularizer): self.l2 = K.cast_to_floatx(l2) def __call__(self, x): - regularization = 0. - if self.l1: - regularization += math_ops.reduce_sum(self.l1 * math_ops.abs(x)) - if self.l2: - regularization += math_ops.reduce_sum(self.l2 * math_ops.square(x)) - return regularization + if self.l1 or self.l2: + regularization = ops.convert_to_tensor(0., dtype=K.floatx()) + if self.l1: + regularization += math_ops.reduce_sum(self.l1 * math_ops.abs(x)) + if self.l2: + regularization += math_ops.reduce_sum(self.l2 * math_ops.square(x)) + return regularization + return None def get_config(self): return {'l1': float(self.l1), 'l2': float(self.l2)}