From 8df4f0ce0acc838dd252ae504d44676c35cb6a6b Mon Sep 17 00:00:00 2001 From: Paul Andrey Date: Fri, 13 Mar 2020 16:15:59 +0100 Subject: [PATCH 1/2] Fixed `Wrapper`'s `get_config` and `from_config`. * `get_config`: properly serialize the wrapped layer. This notably fixes issues when wrapping custom layers that have been registered using `tf.keras.utils.register_keras_serializable`. * `from_config`: properly copy input config to avoid side effects. --- tensorflow/python/keras/layers/wrappers.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py index 97b51501b18..3efbdac8729 100644 --- a/tensorflow/python/keras/layers/wrappers.py +++ b/tensorflow/python/keras/layers/wrappers.py @@ -68,10 +68,7 @@ class Wrapper(Layer): def get_config(self): config = { - 'layer': { - 'class_name': self.layer.__class__.__name__, - 'config': self.layer.get_config() - } + 'layer': generic_utils.serialize_keras_object(self.layer) } base_config = super(Wrapper, self).get_config() return dict(list(base_config.items()) + list(config.items())) @@ -80,7 +77,7 @@ class Wrapper(Layer): def from_config(cls, config, custom_objects=None): from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top # Avoid mutating the input dict - config = config.copy() + config = copy.deepcopy(config) layer = deserialize_layer( config.pop('layer'), custom_objects=custom_objects) return cls(layer, **config) From bb463f8fa4791dff2ae5a14bd55b378941059292 Mon Sep 17 00:00:00 2001 From: Paul Andrey Date: Wed, 25 Mar 2020 15:58:00 +0100 Subject: [PATCH 2/2] Fixed `Bidirectional` serialization. * Fixed `Bidirectional` de-serialization by skipping super call, as `copy.deepcopy` cannot handle instantiated layers within a config dictionary. * On the side, improved `backward_layer` attribute's serialization scheme, so that registered custom layers may be re-instantiated. --- tensorflow/python/keras/layers/wrappers.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py index 3efbdac8729..bfc564afa27 100644 --- a/tensorflow/python/keras/layers/wrappers.py +++ b/tensorflow/python/keras/layers/wrappers.py @@ -423,7 +423,8 @@ class Bidirectional(Wrapper): # Keep the custom backward layer config, so that we can save it later. The # layer's name might be updated below with prefix 'backward_', and we want # to preserve the original config. - self._backward_layer_config = backward_layer.get_config() + self._backward_layer_config = generic_utils.serialize_keras_object( + backward_layer) self.forward_layer._name = 'forward_' + self.forward_layer.name self.backward_layer._name = 'backward_' + self.backward_layer.name @@ -717,26 +718,26 @@ class Bidirectional(Wrapper): config['num_constants'] = self._num_constants if hasattr(self, '_backward_layer_config'): - config['backward_layer'] = { - 'class_name': self.backward_layer.__class__.__name__, - 'config': self._backward_layer_config, - } + config['backward_layer'] = self._backward_layer_config base_config = super(Bidirectional, self).get_config() return dict(list(base_config.items()) + list(config.items())) @classmethod def from_config(cls, config, custom_objects=None): # Instead of updating the input, create a copy and use that. - config = config.copy() + config = copy.deepcopy(config) num_constants = config.pop('num_constants', 0) + # Handle forward layer instantiation (as would parent class). + from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top + config['layer'] = deserialize_layer( + config['layer'], custom_objects=custom_objects) + # Handle (optional) backward layer instantiation. backward_layer_config = config.pop('backward_layer', None) if backward_layer_config is not None: - from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top backward_layer = deserialize_layer( backward_layer_config, custom_objects=custom_objects) config['backward_layer'] = backward_layer - - layer = super(Bidirectional, cls).from_config(config, - custom_objects=custom_objects) + # Instantiate the wrapper, adjust it and return it. + layer = cls(**config) layer._num_constants = num_constants return layer