Merge pull request #37578 from pandrey-fr:patch-1

PiperOrigin-RevId: 303363479
Change-Id: I22cedaffc8b38b3f925c76694c5d688c88a5a41a
This commit is contained in:
TensorFlower Gardener 2020-03-27 10:57:47 -07:00
commit 8c65545d87

View File

@@ -67,12 +67,7 @@ class Wrapper(Layer):
     return None

   def get_config(self):
-    config = {
-        'layer': {
-            'class_name': self.layer.__class__.__name__,
-            'config': self.layer.get_config()
-        }
-    }
+    config = {'layer': generic_utils.serialize_keras_object(self.layer)}
     base_config = super(Wrapper, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
@@ -80,7 +75,7 @@ class Wrapper(Layer):
   def from_config(cls, config, custom_objects=None):
     from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
     # Avoid mutating the input dict
-    config = config.copy()
+    config = copy.deepcopy(config)
     layer = deserialize_layer(
         config.pop('layer'), custom_objects=custom_objects)
     return cls(layer, **config)
@@ -426,7 +421,8 @@ class Bidirectional(Wrapper):
     # Keep the custom backward layer config, so that we can save it later. The
     # layer's name might be updated below with prefix 'backward_', and we want
     # to preserve the original config.
-    self._backward_layer_config = backward_layer.get_config()
+    self._backward_layer_config = generic_utils.serialize_keras_object(
+        backward_layer)
     self.forward_layer._name = 'forward_' + self.forward_layer.name
     self.backward_layer._name = 'backward_' + self.backward_layer.name
@@ -720,26 +716,26 @@ class Bidirectional(Wrapper):
       config['num_constants'] = self._num_constants
     if hasattr(self, '_backward_layer_config'):
-      config['backward_layer'] = {
-          'class_name': self.backward_layer.__class__.__name__,
-          'config': self._backward_layer_config,
-      }
+      config['backward_layer'] = self._backward_layer_config
     base_config = super(Bidirectional, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))

   @classmethod
   def from_config(cls, config, custom_objects=None):
     # Instead of updating the input, create a copy and use that.
-    config = config.copy()
+    config = copy.deepcopy(config)
     num_constants = config.pop('num_constants', 0)
+    # Handle forward layer instantiation (as would parent class).
+    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
+    config['layer'] = deserialize_layer(
+        config['layer'], custom_objects=custom_objects)
+    # Handle (optional) backward layer instantiation.
     backward_layer_config = config.pop('backward_layer', None)
     if backward_layer_config is not None:
-      from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
       backward_layer = deserialize_layer(
           backward_layer_config, custom_objects=custom_objects)
       config['backward_layer'] = backward_layer
-    layer = super(Bidirectional, cls).from_config(config,
-                                                  custom_objects=custom_objects)
+    # Instantiate the wrapper, adjust it and return it.
+    layer = cls(**config)
     layer._num_constants = num_constants
     return layer