From 96ae6b3eec3c0fd0090a57b9a4fbe10790f10790 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 5 Jun 2020 12:45:37 -0700 Subject: [PATCH] Fix Lambda layer deserialization polluting core module. PiperOrigin-RevId: 314978364 Change-Id: Iae42839140cdb0853f91c77a3a1c7ff16d391e02 --- tensorflow/python/keras/layers/core.py | 2 +- tensorflow/python/keras/layers/core_test.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index 61820afdf2a..d22c3ceea45 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -1025,7 +1025,7 @@ class Lambda(Layer): def _parse_function_from_config( cls, config, custom_objects, func_attr_name, module_attr_name, func_type_attr_name): - globs = globals() + globs = globals().copy() module = config.pop(module_attr_name, None) if module in sys.modules: globs.update(sys.modules[module].__dict__) diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py index 70ad63c17eb..15cd8157c0c 100644 --- a/tensorflow/python/keras/layers/core_test.py +++ b/tensorflow/python/keras/layers/core_test.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.layers import core from tensorflow.python.keras.mixed_precision.experimental import policy from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -274,6 +275,12 @@ class LambdaLayerTest(keras_parameterized.TestCase): expected_out = ragged_factory_ops.constant([[2.0], [3.0, 4.0]]) self.assertAllClose(out, expected_out) + def test_lambda_deserialization_does_not_pollute_core(self): + layer = keras.layers.Lambda(lambda x: x + 1) + config = layer.get_config() + keras.layers.Lambda.from_config(config) + self.assertNotIn(self.__class__.__name__, dir(core)) + class TestStatefulLambda(keras_parameterized.TestCase):