Fix Lambda layer deserialization polluting core module.
PiperOrigin-RevId: 314978364 Change-Id: Iae42839140cdb0853f91c77a3a1c7ff16d391e02
This commit is contained in:
parent
d66ae5d65f
commit
96ae6b3eec
|
@ -1025,7 +1025,7 @@ class Lambda(Layer):
|
||||||
def _parse_function_from_config(
|
def _parse_function_from_config(
|
||||||
cls, config, custom_objects, func_attr_name, module_attr_name,
|
cls, config, custom_objects, func_attr_name, module_attr_name,
|
||||||
func_type_attr_name):
|
func_type_attr_name):
|
||||||
globs = globals()
|
globs = globals().copy()
|
||||||
module = config.pop(module_attr_name, None)
|
module = config.pop(module_attr_name, None)
|
||||||
if module in sys.modules:
|
if module in sys.modules:
|
||||||
globs.update(sys.modules[module].__dict__)
|
globs.update(sys.modules[module].__dict__)
|
||||||
|
|
|
@ -29,6 +29,7 @@ from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.keras import keras_parameterized
|
from tensorflow.python.keras import keras_parameterized
|
||||||
from tensorflow.python.keras import testing_utils
|
from tensorflow.python.keras import testing_utils
|
||||||
|
from tensorflow.python.keras.layers import core
|
||||||
from tensorflow.python.keras.mixed_precision.experimental import policy
|
from tensorflow.python.keras.mixed_precision.experimental import policy
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
|
@ -274,6 +275,12 @@ class LambdaLayerTest(keras_parameterized.TestCase):
|
||||||
expected_out = ragged_factory_ops.constant([[2.0], [3.0, 4.0]])
|
expected_out = ragged_factory_ops.constant([[2.0], [3.0, 4.0]])
|
||||||
self.assertAllClose(out, expected_out)
|
self.assertAllClose(out, expected_out)
|
||||||
|
|
||||||
|
def test_lambda_deserialization_does_not_pollute_core(self):
|
||||||
|
layer = keras.layers.Lambda(lambda x: x + 1)
|
||||||
|
config = layer.get_config()
|
||||||
|
keras.layers.Lambda.from_config(config)
|
||||||
|
self.assertNotIn(self.__class__.__name__, dir(core))
|
||||||
|
|
||||||
|
|
||||||
class TestStatefulLambda(keras_parameterized.TestCase):
|
class TestStatefulLambda(keras_parameterized.TestCase):
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue