Fixed Bidirectional
serialization.
* Fixed `Bidirectional` de-serialization by skipping super call, as `copy.deepcopy` cannot handle instantiated layers within a config dictionary. * On the side, improved `backward_layer` attribute's serialization scheme, so that registered custom layers may be re-instantiated.
This commit is contained in:
parent
8df4f0ce0a
commit
bb463f8fa4
@ -423,7 +423,8 @@ class Bidirectional(Wrapper):
|
||||
# Keep the custom backward layer config, so that we can save it later. The
|
||||
# layer's name might be updated below with prefix 'backward_', and we want
|
||||
# to preserve the original config.
|
||||
self._backward_layer_config = backward_layer.get_config()
|
||||
self._backward_layer_config = generic_utils.serialize_keras_object(
|
||||
backward_layer)
|
||||
|
||||
self.forward_layer._name = 'forward_' + self.forward_layer.name
|
||||
self.backward_layer._name = 'backward_' + self.backward_layer.name
|
||||
@ -717,26 +718,26 @@ class Bidirectional(Wrapper):
|
||||
config['num_constants'] = self._num_constants
|
||||
|
||||
if hasattr(self, '_backward_layer_config'):
|
||||
config['backward_layer'] = {
|
||||
'class_name': self.backward_layer.__class__.__name__,
|
||||
'config': self._backward_layer_config,
|
||||
}
|
||||
config['backward_layer'] = self._backward_layer_config
|
||||
base_config = super(Bidirectional, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config, custom_objects=None):
|
||||
# Instead of updating the input, create a copy and use that.
|
||||
config = config.copy()
|
||||
config = copy.deepcopy(config)
|
||||
num_constants = config.pop('num_constants', 0)
|
||||
# Handle forward layer instantiation (as would parent class).
|
||||
from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
|
||||
config['layer'] = deserialize_layer(
|
||||
config['layer'], custom_objects=custom_objects)
|
||||
# Handle (optional) backward layer instantiation.
|
||||
backward_layer_config = config.pop('backward_layer', None)
|
||||
if backward_layer_config is not None:
|
||||
from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
|
||||
backward_layer = deserialize_layer(
|
||||
backward_layer_config, custom_objects=custom_objects)
|
||||
config['backward_layer'] = backward_layer
|
||||
|
||||
layer = super(Bidirectional, cls).from_config(config,
|
||||
custom_objects=custom_objects)
|
||||
# Instantiate the wrapper, adjust it and return it.
|
||||
layer = cls(**config)
|
||||
layer._num_constants = num_constants
|
||||
return layer
|
||||
|
Loading…
Reference in New Issue
Block a user