From e9a7eec1dc6f2e2e72f468c167b93795ff60544b Mon Sep 17 00:00:00 2001 From: PiyushDatta <piyushdattaca@gmail.com> Date: Thu, 7 May 2020 05:54:17 -0400 Subject: [PATCH] We need to bring in the classes from advanced_activations if there are no custom objects specified. When no custom objects are specified, our module_objects/globals() in activations.deserialize() won't contain any advanced_activations. --- tensorflow/python/keras/activations.py | 12 +++++++++++- tensorflow/python/keras/activations_test.py | 7 +++++++ .../python/keras/layers/advanced_activations.py | 2 ++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/keras/activations.py b/tensorflow/python/keras/activations.py index fe0bf5977f9..28d6f18dcf8 100644 --- a/tensorflow/python/keras/activations.py +++ b/tensorflow/python/keras/activations.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import keras_export +from tensorflow.python.keras.layers import advanced_activations # b/123041942 # In TF 2.x, if the `tf.nn.softmax` is used as an activation function in Keras @@ -525,9 +526,18 @@ def deserialize(name, custom_objects=None): ValueError: `Unknown activation function` if the input string does not denote any defined Tensorflow activation function. """ + globs = globals() + + # only replace missing activations, when there are no custom objects + if custom_objects is None: + advanced_activations_globs = advanced_activations.get_globals() + for key,val in advanced_activations_globs.items(): + if key not in globs: + globs[key] = val + return deserialize_keras_object( name, - module_objects=globals(), + module_objects=globs, custom_objects=custom_objects, printable_module_name='activation function') diff --git a/tensorflow/python/keras/activations_test.py b/tensorflow/python/keras/activations_test.py index ddd3863a3f6..e2bdec0dd45 100644 --- a/tensorflow/python/keras/activations_test.py +++ b/tensorflow/python/keras/activations_test.py @@ -65,12 +65,19 @@ class KerasActivationsTest(test.TestCase, parameterized.TestCase): activation = advanced_activations.LeakyReLU(alpha=0.1) layer = core.Dense(3, activation=activation) config = serialization.serialize(layer) + # with custom objects deserialized_layer = serialization.deserialize( config, custom_objects={'LeakyReLU': activation}) self.assertEqual(deserialized_layer.__class__.__name__, layer.__class__.__name__) self.assertEqual(deserialized_layer.activation.__class__.__name__, activation.__class__.__name__) + # without custom objects + deserialized_layer = serialization.deserialize(config) + self.assertEqual(deserialized_layer.__class__.__name__, + layer.__class__.__name__) + self.assertEqual(deserialized_layer.activation.__class__.__name__, + activation.__class__.__name__) def test_softmax(self): x = backend.placeholder(ndim=2) diff --git a/tensorflow/python/keras/layers/advanced_activations.py b/tensorflow/python/keras/layers/advanced_activations.py index e4323b45dc4..e9ce23654fd 100644 --- a/tensorflow/python/keras/layers/advanced_activations.py +++ b/tensorflow/python/keras/layers/advanced_activations.py @@ -29,6 +29,8 @@ from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import keras_export +def get_globals(): + return globals() @keras_export('keras.layers.LeakyReLU') class LeakyReLU(Layer):