Merge pull request #39252 from PiyushDatta:datapi_tflow
PiperOrigin-RevId: 338686808
Change-Id: I9faf849a3dce26d452f506c569caefd67a01e838
commit 080f49a6a5
@@ -130,6 +130,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":backend",
+        "//tensorflow/python/keras/layers:advanced_activations",
         "//tensorflow/python/keras/utils:engine_utils",
     ],
 )
@@ -26,6 +26,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.util import dispatch
 from tensorflow.python.util.tf_export import keras_export
+from tensorflow.python.keras.layers import advanced_activations

 # b/123041942
 # In TF 2.x, if the `tf.nn.softmax` is used as an activation function in Keras
@@ -529,9 +530,17 @@ def deserialize(name, custom_objects=None):
       ValueError: `Unknown activation function` if the input string does not
           denote any defined Tensorflow activation function.
   """
+  globs = globals()
+
+  # only replace missing activations
+  advanced_activations_globs = advanced_activations.get_globals()
+  for key, val in advanced_activations_globs.items():
+    if key not in globs:
+      globs[key] = val
+
   return deserialize_keras_object(
       name,
-      module_objects=globals(),
+      module_objects=globs,
       custom_objects=custom_objects,
       printable_module_name='activation function')
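For context, a minimal sketch of the round-trip this hunk enables, written
against the public tf.keras API (it mirrors the test below; which release
first shipped the change is not stated in this diff):

    import tensorflow as tf

    # A layer whose activation is an advanced-activation layer object.
    layer = tf.keras.layers.Dense(
        3, activation=tf.keras.layers.LeakyReLU(alpha=0.1))
    config = tf.keras.layers.serialize(layer)

    # Previously the name 'LeakyReLU' had to be supplied via custom_objects;
    # with the merged globals it now resolves automatically.
    restored = tf.keras.layers.deserialize(config)
    assert type(restored.activation).__name__ == 'LeakyReLU'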
@@ -65,12 +65,19 @@ class KerasActivationsTest(test.TestCase, parameterized.TestCase):
     activation = advanced_activations.LeakyReLU(alpha=0.1)
     layer = core.Dense(3, activation=activation)
     config = serialization.serialize(layer)
+    # with custom objects
     deserialized_layer = serialization.deserialize(
         config, custom_objects={'LeakyReLU': activation})
     self.assertEqual(deserialized_layer.__class__.__name__,
                      layer.__class__.__name__)
     self.assertEqual(deserialized_layer.activation.__class__.__name__,
                      activation.__class__.__name__)
+    # without custom objects
+    deserialized_layer = serialization.deserialize(config)
+    self.assertEqual(deserialized_layer.__class__.__name__,
+                     layer.__class__.__name__)
+    self.assertEqual(deserialized_layer.activation.__class__.__name__,
+                     activation.__class__.__name__)

   def test_softmax(self):
     x = backend.placeholder(ndim=2)
@@ -30,6 +30,10 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.util.tf_export import keras_export


+def get_globals():
+  return globals()
+
+
 @keras_export('keras.layers.LeakyReLU')
 class LeakyReLU(Layer):
   """Leaky version of a Rectified Linear Unit.
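As a standalone illustration of the lookup-merging pattern used in
deserialize() above (helper names other than get_globals are illustrative,
not TensorFlow's):

    def get_globals():
      # Expose this module's top-level names so callers can resolve the
      # classes and functions defined here by name.
      return globals()

    def merged_lookup(primary, secondary):
      # Fill in only the names missing from primary, so on a name collision
      # the primary table's definition wins.
      merged = dict(primary)
      for key, val in secondary.items():
        merged.setdefault(key, val)
      return merged

Note that globals() returns the live module dictionary, so the loop added to
deserialize() mutates the activations module's namespace in place on first
use; the sketch above copies the table instead, keeping the helper
side-effect free.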