From 8df4f0ce0acc838dd252ae504d44676c35cb6a6b Mon Sep 17 00:00:00 2001 From: Paul Andrey Date: Fri, 13 Mar 2020 16:15:59 +0100 Subject: [PATCH] Fixed `Wrapper`'s `get_config` and `from_config`. * `get_config`: properly serialize the wrapped layer. This notably fixes issues when wrapping custom layers that have been registered using `tf.keras.utils.register_keras_serializable`. * `from_config`: properly copy input config to avoid side effects. --- tensorflow/python/keras/layers/wrappers.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py index 97b51501b18..3efbdac8729 100644 --- a/tensorflow/python/keras/layers/wrappers.py +++ b/tensorflow/python/keras/layers/wrappers.py @@ -68,10 +68,7 @@ class Wrapper(Layer): def get_config(self): config = { - 'layer': { - 'class_name': self.layer.__class__.__name__, - 'config': self.layer.get_config() - } + 'layer': generic_utils.serialize_keras_object(self.layer) } base_config = super(Wrapper, self).get_config() return dict(list(base_config.items()) + list(config.items())) @@ -80,7 +77,7 @@ class Wrapper(Layer): def from_config(cls, config, custom_objects=None): from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top # Avoid mutating the input dict - config = config.copy() + config = copy.deepcopy(config) layer = deserialize_layer( config.pop('layer'), custom_objects=custom_objects) return cls(layer, **config)