Fix ReLU supports_masking
flag.
PiperOrigin-RevId: 343335081 Change-Id: I4a580e3a62630b26d22ad2028f1a096f626524cb
This commit is contained in:
parent
bf7642cd6d
commit
450baeee62
tensorflow/python/keras
@ -408,7 +408,7 @@ class ReLU(Layer):
|
||||
raise ValueError('threshold of Relu layer '
|
||||
'cannot be None. Required a float')
|
||||
|
||||
self.support_masking = True
|
||||
self.supports_masking = True
|
||||
if max_value is not None:
|
||||
max_value = K.cast_to_floatx(max_value)
|
||||
self.max_value = max_value
|
||||
|
@ -34,37 +34,44 @@ class AdvancedActivationsTest(keras_parameterized.TestCase):
|
||||
for alpha in [0., .5, -1.]:
|
||||
testing_utils.layer_test(keras.layers.LeakyReLU,
|
||||
kwargs={'alpha': alpha},
|
||||
input_shape=(2, 3, 4))
|
||||
input_shape=(2, 3, 4),
|
||||
supports_masking=True)
|
||||
|
||||
def test_prelu(self):
|
||||
testing_utils.layer_test(keras.layers.PReLU, kwargs={},
|
||||
input_shape=(2, 3, 4))
|
||||
input_shape=(2, 3, 4),
|
||||
supports_masking=True)
|
||||
|
||||
def test_prelu_share(self):
|
||||
testing_utils.layer_test(keras.layers.PReLU,
|
||||
kwargs={'shared_axes': 1},
|
||||
input_shape=(2, 3, 4))
|
||||
input_shape=(2, 3, 4),
|
||||
supports_masking=True)
|
||||
|
||||
def test_elu(self):
|
||||
for alpha in [0., .5, -1.]:
|
||||
testing_utils.layer_test(keras.layers.ELU,
|
||||
kwargs={'alpha': alpha},
|
||||
input_shape=(2, 3, 4))
|
||||
input_shape=(2, 3, 4),
|
||||
supports_masking=True)
|
||||
|
||||
def test_thresholded_relu(self):
|
||||
testing_utils.layer_test(keras.layers.ThresholdedReLU,
|
||||
kwargs={'theta': 0.5},
|
||||
input_shape=(2, 3, 4))
|
||||
input_shape=(2, 3, 4),
|
||||
supports_masking=True)
|
||||
|
||||
def test_softmax(self):
|
||||
testing_utils.layer_test(keras.layers.Softmax,
|
||||
kwargs={'axis': 1},
|
||||
input_shape=(2, 3, 4))
|
||||
input_shape=(2, 3, 4),
|
||||
supports_masking=True)
|
||||
|
||||
def test_relu(self):
|
||||
testing_utils.layer_test(keras.layers.ReLU,
|
||||
kwargs={'max_value': 10},
|
||||
input_shape=(2, 3, 4))
|
||||
input_shape=(2, 3, 4),
|
||||
supports_masking=True)
|
||||
x = keras.backend.ones((3, 4))
|
||||
if not context.executing_eagerly():
|
||||
# Test that we use `leaky_relu` when appropriate in graph mode.
|
||||
@ -80,7 +87,8 @@ class AdvancedActivationsTest(keras_parameterized.TestCase):
|
||||
ValueError, 'max_value of Relu layer cannot be negative value: -10'):
|
||||
testing_utils.layer_test(keras.layers.ReLU,
|
||||
kwargs={'max_value': -10},
|
||||
input_shape=(2, 3, 4))
|
||||
input_shape=(2, 3, 4),
|
||||
supports_masking=True)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'negative_slope of Relu layer cannot be negative value: -2'):
|
||||
|
@ -100,7 +100,8 @@ def layer_test(layer_cls,
|
||||
validate_training=True,
|
||||
adapt_data=None,
|
||||
custom_objects=None,
|
||||
test_harness=None):
|
||||
test_harness=None,
|
||||
supports_masking=None):
|
||||
"""Test routine for a layer with a single input and single output.
|
||||
|
||||
Arguments:
|
||||
@ -122,6 +123,8 @@ def layer_test(layer_cls,
|
||||
in the layer class. This is helpful for testing custom layers.
|
||||
test_harness: The Tensorflow test, if any, that this function is being
|
||||
called in.
|
||||
supports_masking: Optional boolean to check the `supports_masking` property
|
||||
of the layer. If None, the check will not be performed.
|
||||
|
||||
Returns:
|
||||
The output data (Numpy array) returned by the layer, for additional
|
||||
@ -165,6 +168,13 @@ def layer_test(layer_cls,
|
||||
kwargs = kwargs or {}
|
||||
layer = layer_cls(**kwargs)
|
||||
|
||||
if (supports_masking is not None
|
||||
and layer.supports_masking != supports_masking):
|
||||
raise AssertionError(
|
||||
'When testing layer %s, the `supports_masking` property is %r'
|
||||
'but expected to be %r.\nFull kwargs: %s' %
|
||||
(layer_cls.__name__, layer.supports_masking, supports_masking, kwargs))
|
||||
|
||||
# Test adapt, if data was passed.
|
||||
if adapt_data is not None:
|
||||
layer.adapt(adapt_data)
|
||||
|
Loading…
Reference in New Issue
Block a user