Merge pull request #40073 from howl-anderson:bugfix/tf_issue_37196
PiperOrigin-RevId: 315710318 Change-Id: I4b0c64a217cdc224014778466117e54ea0c5d62c
This commit is contained in:
commit
05ba052c9f
@ -19,6 +19,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.keras import backend
|
||||
@ -28,6 +30,19 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
|
||||
def _check_penalty_number(x):
|
||||
"""check penalty number availability, raise ValueError if failed"""
|
||||
if not isinstance(x, (float, int)):
|
||||
raise ValueError(('Value: {} is not a valid regularization penalty number, '
|
||||
'expected an int or float value').format(x))
|
||||
|
||||
if math.isinf(x) or math.isnan(x):
|
||||
raise ValueError(
|
||||
('Value: {} is not a valid regularization penalty number, '
|
||||
'a positive/negative infinity or NaN is not a property value'
|
||||
).format(x))
|
||||
|
||||
|
||||
@keras_export('keras.regularizers.Regularizer')
|
||||
class Regularizer(object):
|
||||
"""Regularizer base class.
|
||||
@ -215,6 +230,9 @@ class L1L2(Regularizer):
|
||||
"""
|
||||
|
||||
def __init__(self, l1=0., l2=0.): # pylint: disable=redefined-outer-name
|
||||
_check_penalty_number(l1)
|
||||
_check_penalty_number(l2)
|
||||
|
||||
self.l1 = backend.cast_to_floatx(l1)
|
||||
self.l2 = backend.cast_to_floatx(l2)
|
||||
|
||||
@ -251,6 +269,9 @@ class L1(Regularizer):
|
||||
l1 = kwargs.pop('l', l1) # Backwards compatibility
|
||||
if kwargs:
|
||||
raise TypeError('Argument(s) not recognized: %s' % (kwargs,))
|
||||
|
||||
_check_penalty_number(l1)
|
||||
|
||||
self.l1 = backend.cast_to_floatx(l1)
|
||||
|
||||
def __call__(self, x):
|
||||
@ -281,6 +302,9 @@ class L2(Regularizer):
|
||||
l2 = kwargs.pop('l', l2) # Backwards compatibility
|
||||
if kwargs:
|
||||
raise TypeError('Argument(s) not recognized: %s' % (kwargs,))
|
||||
|
||||
_check_penalty_number(l2)
|
||||
|
||||
self.l2 = backend.cast_to_floatx(l2)
|
||||
|
||||
def __call__(self, x):
|
||||
|
Loading…
Reference in New Issue
Block a user