From 49b07e664590f73942bcbaf1c378e59ccab9f04b Mon Sep 17 00:00:00 2001 From: feihugis Date: Mon, 13 Apr 2020 22:59:16 -0500 Subject: [PATCH] Add the test case --- .../python/keras/utils/generic_utils_test.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tensorflow/python/keras/utils/generic_utils_test.py b/tensorflow/python/keras/utils/generic_utils_test.py index 334758871fa..3850591e453 100644 --- a/tensorflow/python/keras/utils/generic_utils_test.py +++ b/tensorflow/python/keras/utils/generic_utils_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.python import keras from tensorflow.python.platform import test @@ -298,6 +300,53 @@ class SerializeKerasObjectTest(test.TestCase): self.assertEqual(new_layer.units, 3) self.assertIs(new_layer.units.fn, serializable_fn) + def test_serializable_with_old_config(self): + # model config generated by tf-1.2.1 + old_model_config = { + 'class_name': 'Sequential', + 'config': [ + { + 'class_name': 'Dense', + 'config': { + 'name': 'dense_1', + 'trainable': True, + 'batch_input_shape': [ + None, + 784 + ], + 'dtype': 'float32', + 'units': 32, + 'activation': 'linear', + 'use_bias': True, + 'kernel_initializer': { + 'class_name': 'Ones', + 'config': { + 'dtype': 'float32' + } + }, + 'bias_initializer': { + 'class_name': 'Zeros', + 'config': { + 'dtype': 'float32' + } + }, + 'kernel_regularizer': None, + 'bias_regularizer': None, + 'activity_regularizer': None, + 'kernel_constraint': None, + 'bias_constraint': None + } + } + ] + } + old_model = keras.utils.generic_utils.deserialize_keras_object( + old_model_config, module_objects={'Sequential': keras.Sequential}) + new_model = keras.Sequential( + [keras.layers.Dense(32, input_dim=784, kernel_initializer='Ones'),]) + input_data = np.random.normal(2, 1, (5, 784)) + output = old_model.predict(input_data) + expected_output = new_model.predict(input_data) + self.assertAllEqual(output, expected_output) class SliceArraysTest(test.TestCase):