Add support for passive serialization in Keras. This lets users pass in any object that supports Keras serialization to any config, and ensures that the resulting to/from config is fully JSON-compatible.
PiperOrigin-RevId: 266159693
This commit is contained in:
parent
332acdd6a7
commit
a4fe7d8b69
@ -986,7 +986,8 @@ class Dense(Layer):
|
||||
|
||||
super(Dense, self).__init__(
|
||||
activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
|
||||
self.units = int(units)
|
||||
|
||||
self.units = units
|
||||
self.activation = activations.get(activation)
|
||||
self.use_bias = use_bias
|
||||
self.kernel_initializer = initializers.get(kernel_initializer)
|
||||
|
||||
@ -43,6 +43,7 @@ from tensorflow.python.keras.layers.recurrent import *
|
||||
from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import *
|
||||
from tensorflow.python.keras.layers.wrappers import *
|
||||
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
|
||||
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
if tf2.enabled():
|
||||
@ -61,7 +62,7 @@ _DESERIALIZATION_TABLE = {
|
||||
|
||||
@keras_export('keras.layers.serialize')
|
||||
def serialize(layer):
|
||||
return {'class_name': layer.__class__.__name__, 'config': layer.get_config()}
|
||||
return serialize_keras_object(layer)
|
||||
|
||||
|
||||
@keras_export('keras.layers.deserialize')
|
||||
|
||||
@ -30,6 +30,19 @@ from tensorflow.python.keras.layers import recurrent_v2 as rnn_v2
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class SerializableInt(int):
|
||||
|
||||
def __new__(cls, value):
|
||||
return int.__new__(cls, value)
|
||||
|
||||
def get_config(self):
|
||||
return {'value': int(self)}
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return cls(**config)
|
||||
|
||||
|
||||
@tf_test_util.run_all_in_graph_and_eager_modes
|
||||
class LayerSerializationTest(parameterized.TestCase, test.TestCase):
|
||||
|
||||
@ -49,6 +62,42 @@ class LayerSerializationTest(parameterized.TestCase, test.TestCase):
|
||||
keras.initializers.Ones)
|
||||
self.assertEqual(new_layer.units, 3)
|
||||
|
||||
def test_implicit_serialize_deserialize_fails_without_object(self):
|
||||
layer = keras.layers.Dense(
|
||||
SerializableInt(3),
|
||||
activation='relu',
|
||||
kernel_initializer='ones',
|
||||
bias_regularizer='l2')
|
||||
config = keras.layers.serialize(layer)
|
||||
# Because we're passing an unknown class here, deserialization should fail
|
||||
# unless we add SerializableInt to the custom object dict.
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'Unknown config_item: SerializableInt.*'):
|
||||
_ = keras.layers.deserialize(config)
|
||||
|
||||
def test_implicit_serialize_deserialize_succeeds_with_object(self):
|
||||
layer = keras.layers.Dense(
|
||||
SerializableInt(3),
|
||||
activation='relu',
|
||||
kernel_initializer='ones',
|
||||
bias_regularizer='l2')
|
||||
config = keras.layers.serialize(layer)
|
||||
# Because we're passing an unknown class here, deserialization should fail
|
||||
# unless we add SerializableInt to the custom object dict.
|
||||
new_layer = keras.layers.deserialize(
|
||||
config, custom_objects={'SerializableInt': SerializableInt})
|
||||
self.assertEqual(new_layer.activation, keras.activations.relu)
|
||||
self.assertEqual(new_layer.bias_regularizer.__class__,
|
||||
keras.regularizers.L1L2)
|
||||
if tf2.enabled():
|
||||
self.assertEqual(new_layer.kernel_initializer.__class__,
|
||||
keras.initializers.OnesV2)
|
||||
else:
|
||||
self.assertEqual(new_layer.kernel_initializer.__class__,
|
||||
keras.initializers.Ones)
|
||||
self.assertEqual(new_layer.units.__class__, SerializableInt)
|
||||
self.assertEqual(new_layer.units, 3)
|
||||
|
||||
@parameterized.parameters(
|
||||
[batchnorm_v1.BatchNormalization, batchnorm_v2.BatchNormalization])
|
||||
def test_serialize_deserialize_batchnorm(self, batchnorm_layer):
|
||||
|
||||
@ -194,8 +194,19 @@ def serialize_keras_object(instance):
|
||||
return None
|
||||
|
||||
if hasattr(instance, 'get_config'):
|
||||
config = instance.get_config()
|
||||
serialization_config = {}
|
||||
for key, item in config.items():
|
||||
try:
|
||||
serialized_item = serialize_keras_object(item)
|
||||
if isinstance(serialized_item, dict):
|
||||
serialized_item['__passive_serialization__'] = True
|
||||
serialization_config[key] = serialized_item
|
||||
except ValueError:
|
||||
serialization_config[key] = item
|
||||
|
||||
name = _get_name_or_custom_name(instance.__class__)
|
||||
return serialize_keras_class_and_config(name, instance.get_config())
|
||||
return serialize_keras_class_and_config(name, serialization_config)
|
||||
if hasattr(instance, '__name__'):
|
||||
return _get_name_or_custom_name(instance)
|
||||
raise ValueError('Cannot serialize', instance)
|
||||
@ -221,7 +232,21 @@ def class_and_config_for_serialized_keras_object(
|
||||
cls = module_objects.get(class_name)
|
||||
if cls is None:
|
||||
raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
|
||||
return (cls, config['config'])
|
||||
|
||||
cls_config = config['config']
|
||||
deserialized_objects = {}
|
||||
for key, item in cls_config.items():
|
||||
if (isinstance(item, dict) and '__passive_serialization__' in item):
|
||||
deserialized_objects[key] = deserialize_keras_object(
|
||||
item,
|
||||
module_objects=module_objects,
|
||||
custom_objects=custom_objects,
|
||||
printable_module_name='config_item')
|
||||
|
||||
for key, item in deserialized_objects.items():
|
||||
cls_config[key] = deserialized_objects[key]
|
||||
|
||||
return (cls, cls_config)
|
||||
|
||||
|
||||
@keras_export('keras.utils.deserialize_keras_object')
|
||||
|
||||
@ -164,6 +164,85 @@ class SerializeKerasObjectTest(test.TestCase):
|
||||
def get_config(self):
|
||||
return {'value': self._value}
|
||||
|
||||
def test_serializable_object(self):
|
||||
|
||||
class SerializableInt(int):
|
||||
"""A serializable object to pass out of a test layer's config."""
|
||||
|
||||
def __new__(cls, value):
|
||||
return int.__new__(cls, value)
|
||||
|
||||
def get_config(self):
|
||||
return {'value': int(self)}
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return cls(**config)
|
||||
|
||||
layer = keras.layers.Dense(
|
||||
SerializableInt(3),
|
||||
activation='relu',
|
||||
kernel_initializer='ones',
|
||||
bias_regularizer='l2')
|
||||
config = keras.layers.serialize(layer)
|
||||
new_layer = keras.layers.deserialize(
|
||||
config, custom_objects={'SerializableInt': SerializableInt})
|
||||
self.assertEqual(new_layer.activation, keras.activations.relu)
|
||||
self.assertEqual(new_layer.bias_regularizer.__class__,
|
||||
keras.regularizers.L1L2)
|
||||
self.assertEqual(new_layer.units.__class__, SerializableInt)
|
||||
self.assertEqual(new_layer.units, 3)
|
||||
|
||||
def test_nested_serializable_object(self):
|
||||
class SerializableInt(int):
|
||||
"""A serializable object to pass out of a test layer's config."""
|
||||
|
||||
def __new__(cls, value):
|
||||
return int.__new__(cls, value)
|
||||
|
||||
def get_config(self):
|
||||
return {'value': int(self)}
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return cls(**config)
|
||||
|
||||
class SerializableNestedInt(int):
|
||||
"""A serializable object containing another serializable object."""
|
||||
|
||||
def __new__(cls, value, int_obj):
|
||||
obj = int.__new__(cls, value)
|
||||
obj.int_obj = int_obj
|
||||
return obj
|
||||
|
||||
def get_config(self):
|
||||
return {'value': int(self), 'int_obj': self.int_obj}
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return cls(**config)
|
||||
|
||||
nested_int = SerializableInt(4)
|
||||
layer = keras.layers.Dense(
|
||||
SerializableNestedInt(3, nested_int),
|
||||
activation='relu',
|
||||
kernel_initializer='ones',
|
||||
bias_regularizer='l2')
|
||||
config = keras.layers.serialize(layer)
|
||||
new_layer = keras.layers.deserialize(
|
||||
config,
|
||||
custom_objects={
|
||||
'SerializableInt': SerializableInt,
|
||||
'SerializableNestedInt': SerializableNestedInt
|
||||
})
|
||||
self.assertEqual(new_layer.activation, keras.activations.relu)
|
||||
self.assertEqual(new_layer.bias_regularizer.__class__,
|
||||
keras.regularizers.L1L2)
|
||||
self.assertEqual(new_layer.units.__class__, SerializableNestedInt)
|
||||
self.assertEqual(new_layer.units, 3)
|
||||
self.assertEqual(new_layer.units.int_obj.__class__, SerializableInt)
|
||||
self.assertEqual(new_layer.units.int_obj, 4)
|
||||
|
||||
|
||||
class SliceArraysTest(test.TestCase):
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user