diff --git a/tensorflow/python/keras/layers/serialization.py b/tensorflow/python/keras/layers/serialization.py index 35202617716..d8daf79d2d5 100644 --- a/tensorflow/python/keras/layers/serialization.py +++ b/tensorflow/python/keras/layers/serialization.py @@ -45,6 +45,15 @@ if tf2.enabled(): from tensorflow.python.keras.layers.normalization_v2 import * # pylint: disable=g-import-not-at-top from tensorflow.python.keras.layers.recurrent_v2 import * # pylint: disable=g-import-not-at-top +# This deserialization table is added for backward compatibility, as in TF 1.13, +# BatchNormalizationV1 and BatchNormalizationV2 are used as class name for v1 +# and v2 version of BatchNormalization, respectively. Here we explictly convert +# them to the canonical name in the config of deserialization. +_DESERIALIZATION_TABLE = { + 'BatchNormalizationV1': 'BatchNormalization', + 'BatchNormalizationV2': 'BatchNormalization', +} + @keras_export('keras.layers.serialize') def serialize(layer): @@ -68,6 +77,9 @@ def deserialize(config, custom_objects=None): globs['Network'] = models.Network globs['Model'] = models.Model globs['Sequential'] = models.Sequential + layer_class_name = config['class_name'] + if layer_class_name in _DESERIALIZATION_TABLE: + config['class_name'] = _DESERIALIZATION_TABLE[layer_class_name] return deserialize_keras_object( config, diff --git a/tensorflow/python/keras/layers/serialization_test.py b/tensorflow/python/keras/layers/serialization_test.py index 5e9fa3cef8d..c0eb2fe9359 100644 --- a/tensorflow/python/keras/layers/serialization_test.py +++ b/tensorflow/python/keras/layers/serialization_test.py @@ -69,6 +69,31 @@ class LayerSerializationTest(parameterized.TestCase, test.TestCase): self.assertEqual(new_layer.gamma_regularizer.__class__, keras.regularizers.L1L2) + @parameterized.parameters( + [batchnorm_v1.BatchNormalization, batchnorm_v2.BatchNormalization]) + def test_deserialize_batchnorm_backwards_compatiblity(self, batchnorm_layer): + layer = batchnorm_layer( + momentum=0.9, beta_initializer='zeros', gamma_regularizer='l2') + config = keras.layers.serialize(layer) + # To simulate if BatchNormalizationV1 or BatchNormalizationV2 appears in the + # saved model. + if batchnorm_layer is batchnorm_v1.BatchNormalization: + config['class_name'] = 'BatchNormalizationV1' + else: + config['class_name'] = 'BatchNormalizationV2' + new_layer = keras.layers.deserialize(config) + self.assertEqual(new_layer.momentum, 0.9) + if tf2.enabled(): + self.assertIsInstance(new_layer, batchnorm_v2.BatchNormalization) + self.assertEqual(new_layer.beta_initializer.__class__, + keras.initializers.ZerosV2) + else: + self.assertIsInstance(new_layer, batchnorm_v1.BatchNormalization) + self.assertEqual(new_layer.beta_initializer.__class__, + keras.initializers.Zeros) + self.assertEqual(new_layer.gamma_regularizer.__class__, + keras.regularizers.L1L2) + @parameterized.parameters([rnn_v1.LSTM, rnn_v2.LSTM]) def test_serialize_deserialize_lstm(self, layer): lstm = layer(5, return_sequences=True)