Fixed Wrapper's get_config and from_config.

* `get_config`: properly serialize the wrapped layer.
  This notably fixes issues when wrapping custom layers that have
  been registered using `tf.keras.utils.register_keras_serializable`.
* `from_config`: properly copy input config to avoid side effects.
This commit is contained in:
Paul Andrey 2020-03-13 16:15:59 +01:00 committed by GitHub
parent 8b6cad7adc
commit 8df4f0ce0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -68,10 +68,7 @@ class Wrapper(Layer):
def get_config(self): def get_config(self):
config = { config = {
'layer': { 'layer': generic_utils.serialize_keras_object(self.layer)
'class_name': self.layer.__class__.__name__,
'config': self.layer.get_config()
}
} }
base_config = super(Wrapper, self).get_config() base_config = super(Wrapper, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
@ -80,7 +77,7 @@ class Wrapper(Layer):
def from_config(cls, config, custom_objects=None): def from_config(cls, config, custom_objects=None):
from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
# Avoid mutating the input dict # Avoid mutating the input dict
config = config.copy() config = copy.deepcopy(config)
layer = deserialize_layer( layer = deserialize_layer(
config.pop('layer'), custom_objects=custom_objects) config.pop('layer'), custom_objects=custom_objects)
return cls(layer, **config) return cls(layer, **config)