We need to bring in the classes from advanced_activations when no custom objects are specified: in that case, the module_objects/globals() mapping used by activations.deserialize() won't contain any of the advanced activations, so deserializing them by name fails.
commit e9a7eec1dc (parent c8822f95b7) in tensorflow/python/keras
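For context, here is a minimal sketch of the string-lookup path this change targets; the helper name and signature are simplified assumptions, not the actual deserialize_keras_object internals. A name is resolved through custom_objects first and module_objects second, so a class missing from activations.py's globals() triggers the `Unknown activation function` error.

    # Hedged sketch, not the real Keras implementation.
    def lookup_by_name(name, module_objects, custom_objects=None):
      if custom_objects and name in custom_objects:
        return custom_objects[name]
      obj = module_objects.get(name)
      if obj is None:
        # Before this commit, 'LeakyReLU' hit this branch when no
        # custom_objects were passed: activations.py's globals() did not
        # include the advanced activation classes.
        raise ValueError('Unknown activation function: ' + name)
      return obj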
tensorflow/python/keras/activations.py
@@ -26,6 +26,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.util import dispatch
 from tensorflow.python.util.tf_export import keras_export
+from tensorflow.python.keras.layers import advanced_activations

 # b/123041942
 # In TF 2.x, if the `tf.nn.softmax` is used as an activation function in Keras
@@ -525,9 +526,18 @@ def deserialize(name, custom_objects=None):
     ValueError: `Unknown activation function` if the input string does not
         denote any defined Tensorflow activation function.
   """
+  globs = globals()
+
+  # only replace missing activations when there are no custom objects
+  if custom_objects is None:
+    advanced_activations_globs = advanced_activations.get_globals()
+    for key, val in advanced_activations_globs.items():
+      if key not in globs:
+        globs[key] = val
+
   return deserialize_keras_object(
       name,
-      module_objects=globals(),
+      module_objects=globs,
       custom_objects=custom_objects,
       printable_module_name='activation function')

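With the fallback above, deserializing an advanced activation should no longer require a custom_objects mapping. A hedged usage sketch (the by-name form is an assumption; the config round-trip is what the test below actually exercises):

    from tensorflow.python.keras import activations

    # 'LeakyReLU' is defined in advanced_activations, not in activations.py,
    # but the merged module_objects table now resolves it by name.
    act = activations.deserialize('LeakyReLU')

Note that globals() returns the live module dictionary, so the names merged into globs persist in the activations module's namespace after the first call; the `if key not in globs` guard keeps existing entries from being overwritten.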
tensorflow/python/keras/activations_test.py
@@ -65,12 +65,19 @@ class KerasActivationsTest(test.TestCase, parameterized.TestCase):
     activation = advanced_activations.LeakyReLU(alpha=0.1)
     layer = core.Dense(3, activation=activation)
     config = serialization.serialize(layer)
+    # with custom objects
     deserialized_layer = serialization.deserialize(
         config, custom_objects={'LeakyReLU': activation})
     self.assertEqual(deserialized_layer.__class__.__name__,
                      layer.__class__.__name__)
     self.assertEqual(deserialized_layer.activation.__class__.__name__,
                      activation.__class__.__name__)
+    # without custom objects
+    deserialized_layer = serialization.deserialize(config)
+    self.assertEqual(deserialized_layer.__class__.__name__,
+                     layer.__class__.__name__)
+    self.assertEqual(deserialized_layer.activation.__class__.__name__,
+                     activation.__class__.__name__)

   def test_softmax(self):
     x = backend.placeholder(ndim=2)
tensorflow/python/keras/layers/advanced_activations.py
@@ -29,6 +29,8 @@ from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import math_ops
 from tensorflow.python.util.tf_export import keras_export

+def get_globals():
+  return globals()

 @keras_export('keras.layers.LeakyReLU')
 class LeakyReLU(Layer):
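A quick illustrative sketch of what the new helper exposes; the assertion is an example, not part of the commit:

    from tensorflow.python.keras.layers import advanced_activations

    # get_globals() simply returns this module's namespace, which is how
    # activations.deserialize() can discover classes such as LeakyReLU.
    assert 'LeakyReLU' in advanced_activations.get_globals()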