When no custom objects are specified, the module_objects mapping (globals()) used by activations.deserialize() does not contain any of the advanced_activations classes, so we need to bring those classes in ourselves.
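For context, a minimal sketch of the failure this fixes, mirroring the test added below. The public tf.keras entry points (rather than the internal modules used in the diff) and the exact error text are assumptions of the sketch:

import tensorflow as tf

# An advanced activation is a Layer, so it is not among the plain
# functions in the activations module's globals().
activation = tf.keras.layers.LeakyReLU(alpha=0.1)
layer = tf.keras.layers.Dense(3, activation=activation)
config = tf.keras.layers.serialize(layer)

# With custom_objects the round trip already worked:
tf.keras.layers.deserialize(config, custom_objects={'LeakyReLU': activation})

# Without custom_objects, deserializing the activation previously failed
# with (roughly) ValueError: Unknown activation function: LeakyReLU.
# After this change the name resolves from advanced_activations.
restored = tf.keras.layers.deserialize(config)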

PiyushDatta 2020-05-07 05:54:17 -04:00 committed by piyushdatta
parent c8822f95b7
commit e9a7eec1dc
3 changed files with 20 additions and 1 deletion

tensorflow/python/keras/activations.py

@@ -26,6 +26,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.util import dispatch
 from tensorflow.python.util.tf_export import keras_export
+from tensorflow.python.keras.layers import advanced_activations
 
 # b/123041942
 # In TF 2.x, if the `tf.nn.softmax` is used as an activation function in Keras
@@ -525,9 +526,18 @@ def deserialize(name, custom_objects=None):
       ValueError: `Unknown activation function` if the input string does not
       denote any defined Tensorflow activation function.
   """
+  globs = globals()
+
+  # only replace missing activations, when there are no custom objects
+  if custom_objects is None:
+    advanced_activations_globs = advanced_activations.get_globals()
+    for key, val in advanced_activations_globs.items():
+      if key not in globs:
+        globs[key] = val
+
   return deserialize_keras_object(
       name,
-      module_objects=globals(),
+      module_objects=globs,
       custom_objects=custom_objects,
       printable_module_name='activation function')
 
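Two properties of this merge are worth noting: it runs only when custom_objects is None, so user-supplied objects keep priority, and names already present in activations' own globals() always win over same-named advanced activations. Note also that globals() returns the module's actual namespace dict, so the merged entries persist in the activations module after the first call; the standalone sketch below uses a copy and illustrative stand-in dicts instead:

# Stand-ins for the two namespaces being merged (illustrative values).
base = {'relu': '<relu function>', 'softmax': '<softmax function>'}
advanced = {'LeakyReLU': '<LeakyReLU class>', 'softmax': '<Softmax layer class>'}

merged = dict(base)  # copy; the patch itself mutates globals() in place
for key, val in advanced.items():
    if key not in merged:
        merged[key] = val  # fill in only the names base lacks

assert merged['softmax'] == '<softmax function>'  # existing entry wins
assert merged['LeakyReLU'] == '<LeakyReLU class>'  # missing name filled in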

tensorflow/python/keras/activations_test.py

@@ -65,12 +65,19 @@ class KerasActivationsTest(test.TestCase, parameterized.TestCase):
     activation = advanced_activations.LeakyReLU(alpha=0.1)
     layer = core.Dense(3, activation=activation)
     config = serialization.serialize(layer)
+    # with custom objects
     deserialized_layer = serialization.deserialize(
         config, custom_objects={'LeakyReLU': activation})
     self.assertEqual(deserialized_layer.__class__.__name__,
                      layer.__class__.__name__)
     self.assertEqual(deserialized_layer.activation.__class__.__name__,
                      activation.__class__.__name__)
+    # without custom objects
+    deserialized_layer = serialization.deserialize(config)
+    self.assertEqual(deserialized_layer.__class__.__name__,
+                     layer.__class__.__name__)
+    self.assertEqual(deserialized_layer.activation.__class__.__name__,
+                     activation.__class__.__name__)
 
   def test_softmax(self):
     x = backend.placeholder(ndim=2)

tensorflow/python/keras/layers/advanced_activations.py

@@ -29,6 +29,8 @@ from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import math_ops
 from tensorflow.python.util.tf_export import keras_export
 
 
+def get_globals():
+  return globals()
 @keras_export('keras.layers.LeakyReLU')
 class LeakyReLU(Layer):
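get_globals() simply hands back this module's namespace dict so that activations.deserialize() can look advanced activation classes up by name. A quick illustrative check (the LeakyReLU entry is confirmed by this diff):

from tensorflow.python.keras.layers import advanced_activations

# get_globals() returns the module's own globals() dict, so the class
# is reachable under its own name.
globs = advanced_activations.get_globals()
assert 'LeakyReLU' in globs
assert globs['LeakyReLU'] is advanced_activations.LeakyReLU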