Merge pull request #40073 from howl-anderson:bugfix/tf_issue_37196

PiperOrigin-RevId: 315710318
Change-Id: I4b0c64a217cdc224014778466117e54ea0c5d62c
This commit is contained in:
TensorFlower Gardener 2020-06-10 10:00:46 -07:00
commit 05ba052c9f

View File

@ -19,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import six
from tensorflow.python.keras import backend
@ -28,6 +30,19 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import keras_export
def _check_penalty_number(x):
"""check penalty number availability, raise ValueError if failed"""
if not isinstance(x, (float, int)):
raise ValueError(('Value: {} is not a valid regularization penalty number, '
'expected an int or float value').format(x))
if math.isinf(x) or math.isnan(x):
raise ValueError(
('Value: {} is not a valid regularization penalty number, '
'a positive/negative infinity or NaN is not a property value'
).format(x))
@keras_export('keras.regularizers.Regularizer')
class Regularizer(object):
"""Regularizer base class.
@ -215,6 +230,9 @@ class L1L2(Regularizer):
"""
def __init__(self, l1=0., l2=0.): # pylint: disable=redefined-outer-name
_check_penalty_number(l1)
_check_penalty_number(l2)
self.l1 = backend.cast_to_floatx(l1)
self.l2 = backend.cast_to_floatx(l2)
@ -251,6 +269,9 @@ class L1(Regularizer):
l1 = kwargs.pop('l', l1) # Backwards compatibility
if kwargs:
raise TypeError('Argument(s) not recognized: %s' % (kwargs,))
_check_penalty_number(l1)
self.l1 = backend.cast_to_floatx(l1)
def __call__(self, x):
@ -281,6 +302,9 @@ class L2(Regularizer):
l2 = kwargs.pop('l', l2) # Backwards compatibility
if kwargs:
raise TypeError('Argument(s) not recognized: %s' % (kwargs,))
_check_penalty_number(l2)
self.l2 = backend.cast_to_floatx(l2)
def __call__(self, x):