From 5350da6b746f076fdd1568d2ac2dd2c86f7d2512 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 8 Feb 2021 20:58:48 +0000 Subject: [PATCH 1/2] Thrown out ValueError when alpha=None is passed for tf.keras.layers.LeakyReLU This PR tries to address the issue raised in 46993 where incorrect nan value is returned when alpha=None is passed for tf.keras.layers.LeakyReLU. The nan could be misleading to users. This PR address the issue and throw out ValueError instead. This PR fixes 46993. Signed-off-by: Yong Tang --- tensorflow/python/keras/layers/advanced_activations.py | 3 +++ .../python/keras/layers/advanced_activations_test.py | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/tensorflow/python/keras/layers/advanced_activations.py b/tensorflow/python/keras/layers/advanced_activations.py index 6a0ae73b9b7..5e0e5c74980 100644 --- a/tensorflow/python/keras/layers/advanced_activations.py +++ b/tensorflow/python/keras/layers/advanced_activations.py @@ -71,6 +71,9 @@ class LeakyReLU(Layer): def __init__(self, alpha=0.3, **kwargs): super(LeakyReLU, self).__init__(**kwargs) + if alpha is None: + raise ValueError('alpha of leaky Relu layer ' + 'cannot be None. Required a float') self.supports_masking = True self.alpha = K.cast_to_floatx(alpha) diff --git a/tensorflow/python/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/layers/advanced_activations_test.py index 96a0c217cac..ff3d8affdbf 100644 --- a/tensorflow/python/keras/layers/advanced_activations_test.py +++ b/tensorflow/python/keras/layers/advanced_activations_test.py @@ -108,6 +108,14 @@ class AdvancedActivationsTest(keras_parameterized.TestCase): run_eagerly=testing_utils.should_run_eagerly()) model.fit(np.ones((10, 10)), np.ones((10, 1)), batch_size=2) + def test_leaky_relu_with_invalid_alpha(self): + with self.assertRaisesRegex( + ValueError, 'alpha of leaky Relu layer cannot be None'): + testing_utils.layer_test(keras.layers.LeakyReLU, + kwargs={'alpha': None}, + input_shape=(2, 3, 4), + supports_masking=True) + if __name__ == '__main__': test.main() From a62609d054deca454923731cf1f4daf8c5c331fb Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 8 Feb 2021 21:09:48 +0000 Subject: [PATCH 2/2] Thrown out ValueError if alpha is None for ELU Signed-off-by: Yong Tang --- tensorflow/python/keras/layers/advanced_activations.py | 3 +++ .../python/keras/layers/advanced_activations_test.py | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/tensorflow/python/keras/layers/advanced_activations.py b/tensorflow/python/keras/layers/advanced_activations.py index 5e0e5c74980..f73a12f21e4 100644 --- a/tensorflow/python/keras/layers/advanced_activations.py +++ b/tensorflow/python/keras/layers/advanced_activations.py @@ -209,6 +209,9 @@ class ELU(Layer): def __init__(self, alpha=1.0, **kwargs): super(ELU, self).__init__(**kwargs) + if alpha is None: + raise ValueError('alpha of ELU layer ' + 'cannot be None. Required a float') self.supports_masking = True self.alpha = K.cast_to_floatx(alpha) diff --git a/tensorflow/python/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/layers/advanced_activations_test.py index ff3d8affdbf..4d2c360c8d7 100644 --- a/tensorflow/python/keras/layers/advanced_activations_test.py +++ b/tensorflow/python/keras/layers/advanced_activations_test.py @@ -109,6 +109,7 @@ class AdvancedActivationsTest(keras_parameterized.TestCase): model.fit(np.ones((10, 10)), np.ones((10, 1)), batch_size=2) def test_leaky_relu_with_invalid_alpha(self): + # Test case for GitHub issue 46993. with self.assertRaisesRegex( ValueError, 'alpha of leaky Relu layer cannot be None'): testing_utils.layer_test(keras.layers.LeakyReLU, @@ -116,6 +117,15 @@ class AdvancedActivationsTest(keras_parameterized.TestCase): input_shape=(2, 3, 4), supports_masking=True) + def test_leaky_elu_with_invalid_alpha(self): + # Test case for GitHub issue 46993. + with self.assertRaisesRegex( + ValueError, 'alpha of ELU layer cannot be None'): + testing_utils.layer_test(keras.layers.ELU, + kwargs={'alpha': None}, + input_shape=(2, 3, 4), + supports_masking=True) + if __name__ == '__main__': test.main()