Fixed Bidirectional serialization.

* Fixed `Bidirectional` deserialization by skipping the super call, since
`copy.deepcopy` cannot handle instantiated layers within a config
dictionary (see the sketch below).
* As a side improvement, changed the `backward_layer` attribute's
serialization scheme so that registered custom layers can be re-instantiated.
commit bb463f8fa4
parent 8df4f0ce0a
Author: Paul Andrey
Date:   2020-03-25 15:58:00 +01:00
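For context on the first point: the parent class's `from_config` deep-copies the incoming config before deserializing it, so once an instantiated layer sits inside that dictionary, the copy step itself breaks. A minimal sketch of the failure mode (version-dependent; on the TF versions this commit targets, layers carry non-copyable internals such as thread-local state):

    import copy

    import tensorflow as tf

    # A config dict holding a live layer object, as happened when the
    # backward layer was deserialized before calling super().from_config.
    config = {'backward_layer': tf.keras.layers.LSTM(4, go_backwards=True)}

    try:
      copy.deepcopy(config)
    except Exception as err:
      # Typically a TypeError about an uncopyable/unpicklable object.
      print('deepcopy failed:', err)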


@@ -423,7 +423,8 @@ class Bidirectional(Wrapper):
       # Keep the custom backward layer config, so that we can save it later. The
       # layer's name might be updated below with prefix 'backward_', and we want
       # to preserve the original config.
-      self._backward_layer_config = backward_layer.get_config()
+      self._backward_layer_config = generic_utils.serialize_keras_object(
+          backward_layer)
 
     self.forward_layer._name = 'forward_' + self.forward_layer.name
     self.backward_layer._name = 'backward_' + self.backward_layer.name
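The difference between the two calls: `get_config()` returns only the layer's constructor arguments, while `serialize_keras_object` wraps them together with the class name that deserialization needs in order to look the class up, including among registered custom objects. A quick illustration through the public alias of the same utility (output shape as in TF 2.x of this era, abbreviated):

    import tensorflow as tf

    layer = tf.keras.layers.LSTM(4, name='rnn')

    print(layer.get_config())
    # {'name': 'rnn', 'units': 4, ...}  -- no class identity

    print(tf.keras.utils.serialize_keras_object(layer))
    # {'class_name': 'LSTM', 'config': {'name': 'rnn', 'units': 4, ...}}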
@@ -717,26 +718,26 @@ class Bidirectional(Wrapper):
       config['num_constants'] = self._num_constants
 
     if hasattr(self, '_backward_layer_config'):
-      config['backward_layer'] = {
-          'class_name': self.backward_layer.__class__.__name__,
-          'config': self._backward_layer_config,
-      }
+      config['backward_layer'] = self._backward_layer_config
     base_config = super(Bidirectional, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
   @classmethod
   def from_config(cls, config, custom_objects=None):
     # Instead of updating the input, create a copy and use that.
-    config = config.copy()
+    config = copy.deepcopy(config)
     num_constants = config.pop('num_constants', 0)
+    # Handle forward layer instantiation (as would parent class).
+    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
+    config['layer'] = deserialize_layer(
+        config['layer'], custom_objects=custom_objects)
+    # Handle (optional) backward layer instantiation.
     backward_layer_config = config.pop('backward_layer', None)
     if backward_layer_config is not None:
-      from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
       backward_layer = deserialize_layer(
           backward_layer_config, custom_objects=custom_objects)
       config['backward_layer'] = backward_layer
-
-    layer = super(Bidirectional, cls).from_config(config,
-                                                  custom_objects=custom_objects)
+    # Instantiate the wrapper, adjust it and return it.
+    layer = cls(**config)
     layer._num_constants = num_constants
     return layer
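Taken together, the two changes make a round trip such as the following work; `CustomLSTM` is a hypothetical user-defined class used here purely for illustration:

    import tensorflow as tf


    class CustomLSTM(tf.keras.layers.LSTM):
      """Stands in for any user-defined recurrent layer."""


    bidi = tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(4),
        backward_layer=CustomLSTM(4, go_backwards=True))

    config = bidi.get_config()
    # config['backward_layer'] now holds
    # {'class_name': 'CustomLSTM', 'config': {...}},
    # so the class can be resolved again at load time.

    restored = tf.keras.layers.Bidirectional.from_config(
        config, custom_objects={'CustomLSTM': CustomLSTM})
    assert isinstance(restored.backward_layer, CustomLSTM)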