Fixed Wrapper
's get_config
and from_config
.
* `get_config`: properly serialize the wrapped layer. This notably fixes issues when wrapping custom layers that have been registered using `tf.keras.utils.register_keras_serializable`. * `from_config`: properly copy input config to avoid side effects.
This commit is contained in:
parent
8b6cad7adc
commit
8df4f0ce0a
@ -68,10 +68,7 @@ class Wrapper(Layer):
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
'layer': {
|
||||
'class_name': self.layer.__class__.__name__,
|
||||
'config': self.layer.get_config()
|
||||
}
|
||||
'layer': generic_utils.serialize_keras_object(self.layer)
|
||||
}
|
||||
base_config = super(Wrapper, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
@ -80,7 +77,7 @@ class Wrapper(Layer):
|
||||
def from_config(cls, config, custom_objects=None):
|
||||
from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
|
||||
# Avoid mutating the input dict
|
||||
config = config.copy()
|
||||
config = copy.deepcopy(config)
|
||||
layer = deserialize_layer(
|
||||
config.pop('layer'), custom_objects=custom_objects)
|
||||
return cls(layer, **config)
|
||||
|
Loading…
Reference in New Issue
Block a user