Add support for passive serialization in Keras. This lets users pass in any object that supports Keras serialization to any config, and ensures that the resulting to/from config is fully JSON-compatible.

PiperOrigin-RevId: 266159693
This commit is contained in:
A. Unique TensorFlower 2019-08-29 09:33:18 -07:00 committed by TensorFlower Gardener
parent 332acdd6a7
commit a4fe7d8b69
5 changed files with 159 additions and 4 deletions

View File

@ -986,7 +986,8 @@ class Dense(Layer):
super(Dense, self).__init__(
activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
self.units = int(units)
self.units = units
self.activation = activations.get(activation)
self.use_bias = use_bias
self.kernel_initializer = initializers.get(kernel_initializer)

View File

@ -43,6 +43,7 @@ from tensorflow.python.keras.layers.recurrent import *
from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import *
from tensorflow.python.keras.layers.wrappers import *
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.util.tf_export import keras_export
if tf2.enabled():
@ -61,7 +62,7 @@ _DESERIALIZATION_TABLE = {
@keras_export('keras.layers.serialize')
def serialize(layer):
return {'class_name': layer.__class__.__name__, 'config': layer.get_config()}
return serialize_keras_object(layer)
@keras_export('keras.layers.deserialize')

View File

@ -30,6 +30,19 @@ from tensorflow.python.keras.layers import recurrent_v2 as rnn_v2
from tensorflow.python.platform import test
class SerializableInt(int):
def __new__(cls, value):
return int.__new__(cls, value)
def get_config(self):
return {'value': int(self)}
@classmethod
def from_config(cls, config):
return cls(**config)
@tf_test_util.run_all_in_graph_and_eager_modes
class LayerSerializationTest(parameterized.TestCase, test.TestCase):
@ -49,6 +62,42 @@ class LayerSerializationTest(parameterized.TestCase, test.TestCase):
keras.initializers.Ones)
self.assertEqual(new_layer.units, 3)
def test_implicit_serialize_deserialize_fails_without_object(self):
layer = keras.layers.Dense(
SerializableInt(3),
activation='relu',
kernel_initializer='ones',
bias_regularizer='l2')
config = keras.layers.serialize(layer)
# Because we're passing an unknown class here, deserialization should fail
# unless we add SerializableInt to the custom object dict.
with self.assertRaisesRegex(ValueError,
'Unknown config_item: SerializableInt.*'):
_ = keras.layers.deserialize(config)
def test_implicit_serialize_deserialize_succeeds_with_object(self):
layer = keras.layers.Dense(
SerializableInt(3),
activation='relu',
kernel_initializer='ones',
bias_regularizer='l2')
config = keras.layers.serialize(layer)
# Because we're passing an unknown class here, deserialization should fail
# unless we add SerializableInt to the custom object dict.
new_layer = keras.layers.deserialize(
config, custom_objects={'SerializableInt': SerializableInt})
self.assertEqual(new_layer.activation, keras.activations.relu)
self.assertEqual(new_layer.bias_regularizer.__class__,
keras.regularizers.L1L2)
if tf2.enabled():
self.assertEqual(new_layer.kernel_initializer.__class__,
keras.initializers.OnesV2)
else:
self.assertEqual(new_layer.kernel_initializer.__class__,
keras.initializers.Ones)
self.assertEqual(new_layer.units.__class__, SerializableInt)
self.assertEqual(new_layer.units, 3)
@parameterized.parameters(
[batchnorm_v1.BatchNormalization, batchnorm_v2.BatchNormalization])
def test_serialize_deserialize_batchnorm(self, batchnorm_layer):

View File

@ -194,8 +194,19 @@ def serialize_keras_object(instance):
return None
if hasattr(instance, 'get_config'):
config = instance.get_config()
serialization_config = {}
for key, item in config.items():
try:
serialized_item = serialize_keras_object(item)
if isinstance(serialized_item, dict):
serialized_item['__passive_serialization__'] = True
serialization_config[key] = serialized_item
except ValueError:
serialization_config[key] = item
name = _get_name_or_custom_name(instance.__class__)
return serialize_keras_class_and_config(name, instance.get_config())
return serialize_keras_class_and_config(name, serialization_config)
if hasattr(instance, '__name__'):
return _get_name_or_custom_name(instance)
raise ValueError('Cannot serialize', instance)
@ -221,7 +232,21 @@ def class_and_config_for_serialized_keras_object(
cls = module_objects.get(class_name)
if cls is None:
raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
return (cls, config['config'])
cls_config = config['config']
deserialized_objects = {}
for key, item in cls_config.items():
if (isinstance(item, dict) and '__passive_serialization__' in item):
deserialized_objects[key] = deserialize_keras_object(
item,
module_objects=module_objects,
custom_objects=custom_objects,
printable_module_name='config_item')
for key, item in deserialized_objects.items():
cls_config[key] = deserialized_objects[key]
return (cls, cls_config)
@keras_export('keras.utils.deserialize_keras_object')

View File

@ -164,6 +164,85 @@ class SerializeKerasObjectTest(test.TestCase):
def get_config(self):
return {'value': self._value}
def test_serializable_object(self):
class SerializableInt(int):
"""A serializable object to pass out of a test layer's config."""
def __new__(cls, value):
return int.__new__(cls, value)
def get_config(self):
return {'value': int(self)}
@classmethod
def from_config(cls, config):
return cls(**config)
layer = keras.layers.Dense(
SerializableInt(3),
activation='relu',
kernel_initializer='ones',
bias_regularizer='l2')
config = keras.layers.serialize(layer)
new_layer = keras.layers.deserialize(
config, custom_objects={'SerializableInt': SerializableInt})
self.assertEqual(new_layer.activation, keras.activations.relu)
self.assertEqual(new_layer.bias_regularizer.__class__,
keras.regularizers.L1L2)
self.assertEqual(new_layer.units.__class__, SerializableInt)
self.assertEqual(new_layer.units, 3)
def test_nested_serializable_object(self):
class SerializableInt(int):
"""A serializable object to pass out of a test layer's config."""
def __new__(cls, value):
return int.__new__(cls, value)
def get_config(self):
return {'value': int(self)}
@classmethod
def from_config(cls, config):
return cls(**config)
class SerializableNestedInt(int):
"""A serializable object containing another serializable object."""
def __new__(cls, value, int_obj):
obj = int.__new__(cls, value)
obj.int_obj = int_obj
return obj
def get_config(self):
return {'value': int(self), 'int_obj': self.int_obj}
@classmethod
def from_config(cls, config):
return cls(**config)
nested_int = SerializableInt(4)
layer = keras.layers.Dense(
SerializableNestedInt(3, nested_int),
activation='relu',
kernel_initializer='ones',
bias_regularizer='l2')
config = keras.layers.serialize(layer)
new_layer = keras.layers.deserialize(
config,
custom_objects={
'SerializableInt': SerializableInt,
'SerializableNestedInt': SerializableNestedInt
})
self.assertEqual(new_layer.activation, keras.activations.relu)
self.assertEqual(new_layer.bias_regularizer.__class__,
keras.regularizers.L1L2)
self.assertEqual(new_layer.units.__class__, SerializableNestedInt)
self.assertEqual(new_layer.units, 3)
self.assertEqual(new_layer.units.int_obj.__class__, SerializableInt)
self.assertEqual(new_layer.units.int_obj, 4)
class SliceArraysTest(test.TestCase):