Fix Lambda layer deserialization polluting core module.

PiperOrigin-RevId: 314978364
Change-Id: Iae42839140cdb0853f91c77a3a1c7ff16d391e02
This commit is contained in:
A. Unique TensorFlower 2020-06-05 12:45:37 -07:00 committed by TensorFlower Gardener
parent d66ae5d65f
commit 96ae6b3eec
2 changed files with 8 additions and 1 deletions

View File

@ -1025,7 +1025,7 @@ class Lambda(Layer):
def _parse_function_from_config( def _parse_function_from_config(
cls, config, custom_objects, func_attr_name, module_attr_name, cls, config, custom_objects, func_attr_name, module_attr_name,
func_type_attr_name): func_type_attr_name):
globs = globals() globs = globals().copy()
module = config.pop(module_attr_name, None) module = config.pop(module_attr_name, None)
if module in sys.modules: if module in sys.modules:
globs.update(sys.modules[module].__dict__) globs.update(sys.modules[module].__dict__)

View File

@ -29,6 +29,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.mixed_precision.experimental import policy from tensorflow.python.keras.mixed_precision.experimental import policy
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
@ -274,6 +275,12 @@ class LambdaLayerTest(keras_parameterized.TestCase):
expected_out = ragged_factory_ops.constant([[2.0], [3.0, 4.0]]) expected_out = ragged_factory_ops.constant([[2.0], [3.0, 4.0]])
self.assertAllClose(out, expected_out) self.assertAllClose(out, expected_out)
def test_lambda_deserialization_does_not_pollute_core(self):
layer = keras.layers.Lambda(lambda x: x + 1)
config = layer.get_config()
keras.layers.Lambda.from_config(config)
self.assertNotIn(self.__class__.__name__, dir(core))
class TestStatefulLambda(keras_parameterized.TestCase): class TestStatefulLambda(keras_parameterized.TestCase):