Merge pull request #39252 from PiyushDatta:datapi_tflow

PiperOrigin-RevId: 338686808
Change-Id: I9faf849a3dce26d452f506c569caefd67a01e838
This commit is contained in:
TensorFlower Gardener 2020-10-23 09:26:03 -07:00
commit 080f49a6a5
4 changed files with 22 additions and 1 deletions

View File

@ -130,6 +130,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":backend",
"//tensorflow/python/keras/layers:advanced_activations",
"//tensorflow/python/keras/utils:engine_utils",
],
)

View File

@ -26,6 +26,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import keras_export
from tensorflow.python.keras.layers import advanced_activations
# b/123041942
# In TF 2.x, if the `tf.nn.softmax` is used as an activation function in Keras
@ -529,9 +530,17 @@ def deserialize(name, custom_objects=None):
ValueError: `Unknown activation function` if the input string does not
denote any defined Tensorflow activation function.
"""
globs = globals()
# only replace missing activations
advanced_activations_globs = advanced_activations.get_globals()
for key, val in advanced_activations_globs.items():
if key not in globs:
globs[key] = val
return deserialize_keras_object(
name,
module_objects=globals(),
module_objects=globs,
custom_objects=custom_objects,
printable_module_name='activation function')

View File

@ -65,12 +65,19 @@ class KerasActivationsTest(test.TestCase, parameterized.TestCase):
activation = advanced_activations.LeakyReLU(alpha=0.1)
layer = core.Dense(3, activation=activation)
config = serialization.serialize(layer)
# with custom objects
deserialized_layer = serialization.deserialize(
config, custom_objects={'LeakyReLU': activation})
self.assertEqual(deserialized_layer.__class__.__name__,
layer.__class__.__name__)
self.assertEqual(deserialized_layer.activation.__class__.__name__,
activation.__class__.__name__)
# without custom objects
deserialized_layer = serialization.deserialize(config)
self.assertEqual(deserialized_layer.__class__.__name__,
layer.__class__.__name__)
self.assertEqual(deserialized_layer.activation.__class__.__name__,
activation.__class__.__name__)
def test_softmax(self):
x = backend.placeholder(ndim=2)

View File

@ -30,6 +30,10 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import keras_export
def get_globals():
return globals()
@keras_export('keras.layers.LeakyReLU')
class LeakyReLU(Layer):
"""Leaky version of a Rectified Linear Unit.