Merge pull request #37578 from pandrey-fr:patch-1
PiperOrigin-RevId: 303363479 Change-Id: I22cedaffc8b38b3f925c76694c5d688c88a5a41a
This commit is contained in:
commit
8c65545d87
@ -67,12 +67,7 @@ class Wrapper(Layer):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
config = {
|
config = {'layer': generic_utils.serialize_keras_object(self.layer)}
|
||||||
'layer': {
|
|
||||||
'class_name': self.layer.__class__.__name__,
|
|
||||||
'config': self.layer.get_config()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
base_config = super(Wrapper, self).get_config()
|
base_config = super(Wrapper, self).get_config()
|
||||||
return dict(list(base_config.items()) + list(config.items()))
|
return dict(list(base_config.items()) + list(config.items()))
|
||||||
|
|
||||||
@ -80,7 +75,7 @@ class Wrapper(Layer):
|
|||||||
def from_config(cls, config, custom_objects=None):
|
def from_config(cls, config, custom_objects=None):
|
||||||
from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
|
from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
|
||||||
# Avoid mutating the input dict
|
# Avoid mutating the input dict
|
||||||
config = config.copy()
|
config = copy.deepcopy(config)
|
||||||
layer = deserialize_layer(
|
layer = deserialize_layer(
|
||||||
config.pop('layer'), custom_objects=custom_objects)
|
config.pop('layer'), custom_objects=custom_objects)
|
||||||
return cls(layer, **config)
|
return cls(layer, **config)
|
||||||
@ -426,7 +421,8 @@ class Bidirectional(Wrapper):
|
|||||||
# Keep the custom backward layer config, so that we can save it later. The
|
# Keep the custom backward layer config, so that we can save it later. The
|
||||||
# layer's name might be updated below with prefix 'backward_', and we want
|
# layer's name might be updated below with prefix 'backward_', and we want
|
||||||
# to preserve the original config.
|
# to preserve the original config.
|
||||||
self._backward_layer_config = backward_layer.get_config()
|
self._backward_layer_config = generic_utils.serialize_keras_object(
|
||||||
|
backward_layer)
|
||||||
|
|
||||||
self.forward_layer._name = 'forward_' + self.forward_layer.name
|
self.forward_layer._name = 'forward_' + self.forward_layer.name
|
||||||
self.backward_layer._name = 'backward_' + self.backward_layer.name
|
self.backward_layer._name = 'backward_' + self.backward_layer.name
|
||||||
@ -720,26 +716,26 @@ class Bidirectional(Wrapper):
|
|||||||
config['num_constants'] = self._num_constants
|
config['num_constants'] = self._num_constants
|
||||||
|
|
||||||
if hasattr(self, '_backward_layer_config'):
|
if hasattr(self, '_backward_layer_config'):
|
||||||
config['backward_layer'] = {
|
config['backward_layer'] = self._backward_layer_config
|
||||||
'class_name': self.backward_layer.__class__.__name__,
|
|
||||||
'config': self._backward_layer_config,
|
|
||||||
}
|
|
||||||
base_config = super(Bidirectional, self).get_config()
|
base_config = super(Bidirectional, self).get_config()
|
||||||
return dict(list(base_config.items()) + list(config.items()))
|
return dict(list(base_config.items()) + list(config.items()))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config, custom_objects=None):
|
def from_config(cls, config, custom_objects=None):
|
||||||
# Instead of updating the input, create a copy and use that.
|
# Instead of updating the input, create a copy and use that.
|
||||||
config = config.copy()
|
config = copy.deepcopy(config)
|
||||||
num_constants = config.pop('num_constants', 0)
|
num_constants = config.pop('num_constants', 0)
|
||||||
|
# Handle forward layer instantiation (as would parent class).
|
||||||
|
from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
|
||||||
|
config['layer'] = deserialize_layer(
|
||||||
|
config['layer'], custom_objects=custom_objects)
|
||||||
|
# Handle (optional) backward layer instantiation.
|
||||||
backward_layer_config = config.pop('backward_layer', None)
|
backward_layer_config = config.pop('backward_layer', None)
|
||||||
if backward_layer_config is not None:
|
if backward_layer_config is not None:
|
||||||
from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
|
|
||||||
backward_layer = deserialize_layer(
|
backward_layer = deserialize_layer(
|
||||||
backward_layer_config, custom_objects=custom_objects)
|
backward_layer_config, custom_objects=custom_objects)
|
||||||
config['backward_layer'] = backward_layer
|
config['backward_layer'] = backward_layer
|
||||||
|
# Instantiate the wrapper, adjust it and return it.
|
||||||
layer = super(Bidirectional, cls).from_config(config,
|
layer = cls(**config)
|
||||||
custom_objects=custom_objects)
|
|
||||||
layer._num_constants = num_constants
|
layer._num_constants = num_constants
|
||||||
return layer
|
return layer
|
||||||
|
Loading…
Reference in New Issue
Block a user