Fix Lambda layer deserialization polluting core module.
PiperOrigin-RevId: 314978364 Change-Id: Iae42839140cdb0853f91c77a3a1c7ff16d391e02
This commit is contained in:
parent
d66ae5d65f
commit
96ae6b3eec
|
@ -1025,7 +1025,7 @@ class Lambda(Layer):
|
|||
def _parse_function_from_config(
|
||||
cls, config, custom_objects, func_attr_name, module_attr_name,
|
||||
func_type_attr_name):
|
||||
globs = globals()
|
||||
globs = globals().copy()
|
||||
module = config.pop(module_attr_name, None)
|
||||
if module in sys.modules:
|
||||
globs.update(sys.modules[module].__dict__)
|
||||
|
|
|
@ -29,6 +29,7 @@ from tensorflow.python.framework import ops
|
|||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.keras.mixed_precision.experimental import policy
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
@ -274,6 +275,12 @@ class LambdaLayerTest(keras_parameterized.TestCase):
|
|||
expected_out = ragged_factory_ops.constant([[2.0], [3.0, 4.0]])
|
||||
self.assertAllClose(out, expected_out)
|
||||
|
||||
def test_lambda_deserialization_does_not_pollute_core(self):
|
||||
layer = keras.layers.Lambda(lambda x: x + 1)
|
||||
config = layer.get_config()
|
||||
keras.layers.Lambda.from_config(config)
|
||||
self.assertNotIn(self.__class__.__name__, dir(core))
|
||||
|
||||
|
||||
class TestStatefulLambda(keras_parameterized.TestCase):
|
||||
|
||||
|
|
Loading…
Reference in New Issue