Merge pull request #38339 from feihugis:issue_class_and_config_for_serialized_keras_object
PiperOrigin-RevId: 314589293 Change-Id: I42a04d7d27764aba10bd935e5cd80ab97d2833a8
This commit is contained in:
commit
e19e2d29d5
|
@ -296,6 +296,15 @@ def class_and_config_for_serialized_keras_object(
|
||||||
raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
|
raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
|
||||||
|
|
||||||
cls_config = config['config']
|
cls_config = config['config']
|
||||||
|
# Check if `cls_config` is a list. If it is a list, return the class and the
|
||||||
|
# associated class configs for recursively deserialization. This case will
|
||||||
|
# happen on the old version of sequential model (e.g. `keras_version` ==
|
||||||
|
# "2.0.6"), which is serialized in a different structure, for example
|
||||||
|
# "{'class_name': 'Sequential',
|
||||||
|
# 'config': [{'class_name': 'Embedding', 'config': ...}, {}, ...]}".
|
||||||
|
if isinstance(cls_config, list):
|
||||||
|
return (cls, cls_config)
|
||||||
|
|
||||||
deserialized_objects = {}
|
deserialized_objects = {}
|
||||||
for key, item in cls_config.items():
|
for key, item in cls_config.items():
|
||||||
if isinstance(item, dict) and '__passive_serialization__' in item:
|
if isinstance(item, dict) and '__passive_serialization__' in item:
|
||||||
|
|
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
@ -298,6 +300,51 @@ class SerializeKerasObjectTest(test.TestCase):
|
||||||
self.assertEqual(new_layer.units, 3)
|
self.assertEqual(new_layer.units, 3)
|
||||||
self.assertIs(new_layer.units.fn, serializable_fn)
|
self.assertIs(new_layer.units.fn, serializable_fn)
|
||||||
|
|
||||||
|
def test_serializable_with_old_config(self):
|
||||||
|
# model config generated by tf-1.2.1
|
||||||
|
old_model_config = {
|
||||||
|
'class_name':
|
||||||
|
'Sequential',
|
||||||
|
'config': [{
|
||||||
|
'class_name': 'Dense',
|
||||||
|
'config': {
|
||||||
|
'name': 'dense_1',
|
||||||
|
'trainable': True,
|
||||||
|
'batch_input_shape': [None, 784],
|
||||||
|
'dtype': 'float32',
|
||||||
|
'units': 32,
|
||||||
|
'activation': 'linear',
|
||||||
|
'use_bias': True,
|
||||||
|
'kernel_initializer': {
|
||||||
|
'class_name': 'Ones',
|
||||||
|
'config': {
|
||||||
|
'dtype': 'float32'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'bias_initializer': {
|
||||||
|
'class_name': 'Zeros',
|
||||||
|
'config': {
|
||||||
|
'dtype': 'float32'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'kernel_regularizer': None,
|
||||||
|
'bias_regularizer': None,
|
||||||
|
'activity_regularizer': None,
|
||||||
|
'kernel_constraint': None,
|
||||||
|
'bias_constraint': None
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
old_model = keras.utils.generic_utils.deserialize_keras_object(
|
||||||
|
old_model_config, module_objects={'Sequential': keras.Sequential})
|
||||||
|
new_model = keras.Sequential([
|
||||||
|
keras.layers.Dense(32, input_dim=784, kernel_initializer='Ones'),
|
||||||
|
])
|
||||||
|
input_data = np.random.normal(2, 1, (5, 784))
|
||||||
|
output = old_model.predict(input_data)
|
||||||
|
expected_output = new_model.predict(input_data)
|
||||||
|
self.assertAllEqual(output, expected_output)
|
||||||
|
|
||||||
|
|
||||||
class SliceArraysTest(test.TestCase):
|
class SliceArraysTest(test.TestCase):
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue