Fixed Bidirectional serialization.

* Fixed `Bidirectional` de-serialization by skipping super call, as 
`copy.deepcopy` cannot handle instantiated layers within a config 
dictionary.
* On the side, improved `backward_layer` attribute's serialization 
scheme, so that registered custom layers may be re-instantiated.
This commit is contained in:
Paul Andrey 2020-03-25 15:58:00 +01:00
parent 8df4f0ce0a
commit bb463f8fa4

View File

@ -423,7 +423,8 @@ class Bidirectional(Wrapper):
# Keep the custom backward layer config, so that we can save it later. The
# layer's name might be updated below with prefix 'backward_', and we want
# to preserve the original config.
self._backward_layer_config = backward_layer.get_config()
self._backward_layer_config = generic_utils.serialize_keras_object(
backward_layer)
self.forward_layer._name = 'forward_' + self.forward_layer.name
self.backward_layer._name = 'backward_' + self.backward_layer.name
@ -717,26 +718,26 @@ class Bidirectional(Wrapper):
config['num_constants'] = self._num_constants
if hasattr(self, '_backward_layer_config'):
config['backward_layer'] = {
'class_name': self.backward_layer.__class__.__name__,
'config': self._backward_layer_config,
}
config['backward_layer'] = self._backward_layer_config
base_config = super(Bidirectional, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config, custom_objects=None):
# Instead of updating the input, create a copy and use that.
config = config.copy()
config = copy.deepcopy(config)
num_constants = config.pop('num_constants', 0)
# Handle forward layer instantiation (as would parent class).
from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
config['layer'] = deserialize_layer(
config['layer'], custom_objects=custom_objects)
# Handle (optional) backward layer instantiation.
backward_layer_config = config.pop('backward_layer', None)
if backward_layer_config is not None:
from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
backward_layer = deserialize_layer(
backward_layer_config, custom_objects=custom_objects)
config['backward_layer'] = backward_layer
layer = super(Bidirectional, cls).from_config(config,
custom_objects=custom_objects)
# Instantiate the wrapper, adjust it and return it.
layer = cls(**config)
layer._num_constants = num_constants
return layer