From cc791f59c4027fb047469ea0854e450d23d8f1c1 Mon Sep 17 00:00:00 2001 From: feihugis Date: Tue, 7 Apr 2020 22:07:17 -0500 Subject: [PATCH 1/2] Make keras model load compatible with old version of models --- tensorflow/python/keras/utils/generic_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py index 27015cbc8f2..f86bc30bcde 100644 --- a/tensorflow/python/keras/utils/generic_utils.py +++ b/tensorflow/python/keras/utils/generic_utils.py @@ -296,6 +296,15 @@ def class_and_config_for_serialized_keras_object( raise ValueError('Unknown ' + printable_module_name + ': ' + class_name) cls_config = config['config'] + # Check if `cls_config` is a list. If it is a list, return the class and the + # associated class configs for recursively deserialization. This case will + # happen on the old version of sequential model (e.g. `keras_version` == + # "2.0.6"), which is serialized in a different structure, for example + # "{'class_name': 'Sequential', + # 'config': [{'class_name': 'Embedding', 'config': ...}, {}, ...]}". + if isinstance(cls_config, list): + return (cls, cls_config) + deserialized_objects = {} for key, item in cls_config.items(): if isinstance(item, dict) and '__passive_serialization__' in item: From 49b07e664590f73942bcbaf1c378e59ccab9f04b Mon Sep 17 00:00:00 2001 From: feihugis Date: Mon, 13 Apr 2020 22:59:16 -0500 Subject: [PATCH 2/2] Add the test case --- .../python/keras/utils/generic_utils_test.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tensorflow/python/keras/utils/generic_utils_test.py b/tensorflow/python/keras/utils/generic_utils_test.py index 334758871fa..3850591e453 100644 --- a/tensorflow/python/keras/utils/generic_utils_test.py +++ b/tensorflow/python/keras/utils/generic_utils_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.python import keras from tensorflow.python.platform import test @@ -298,6 +300,53 @@ class SerializeKerasObjectTest(test.TestCase): self.assertEqual(new_layer.units, 3) self.assertIs(new_layer.units.fn, serializable_fn) + def test_serializable_with_old_config(self): + # model config generated by tf-1.2.1 + old_model_config = { + 'class_name': 'Sequential', + 'config': [ + { + 'class_name': 'Dense', + 'config': { + 'name': 'dense_1', + 'trainable': True, + 'batch_input_shape': [ + None, + 784 + ], + 'dtype': 'float32', + 'units': 32, + 'activation': 'linear', + 'use_bias': True, + 'kernel_initializer': { + 'class_name': 'Ones', + 'config': { + 'dtype': 'float32' + } + }, + 'bias_initializer': { + 'class_name': 'Zeros', + 'config': { + 'dtype': 'float32' + } + }, + 'kernel_regularizer': None, + 'bias_regularizer': None, + 'activity_regularizer': None, + 'kernel_constraint': None, + 'bias_constraint': None + } + } + ] + } + old_model = keras.utils.generic_utils.deserialize_keras_object( + old_model_config, module_objects={'Sequential': keras.Sequential}) + new_model = keras.Sequential( + [keras.layers.Dense(32, input_dim=784, kernel_initializer='Ones'),]) + input_data = np.random.normal(2, 1, (5, 784)) + output = old_model.predict(input_data) + expected_output = new_model.predict(input_data) + self.assertAllEqual(output, expected_output) class SliceArraysTest(test.TestCase):