From 18493fb2a3d83d3cd4e7cfcdbf11d314de18dadd Mon Sep 17 00:00:00 2001
From: Yanhui Liang
Date: Tue, 19 Mar 2019 17:25:04 -0700
Subject: [PATCH] Make the BatchNormalization layer backward-compatible.

In TF 1.13, the BatchNormalization layer is named BatchNormalizationV1 for
the v1 version and BatchNormalizationV2 for the v2 version, so
BatchNormalizationV1 and BatchNormalizationV2 appear in saved models. This CL
explicitly converts the v1 and v2 versioned names to the canonical name for
backward compatibility.

PiperOrigin-RevId: 239304519
---
 .../python/keras/layers/serialization.py      | 12 +++++++++
 .../python/keras/layers/serialization_test.py | 25 +++++++++++++++++++
 2 files changed, 37 insertions(+)

diff --git a/tensorflow/python/keras/layers/serialization.py b/tensorflow/python/keras/layers/serialization.py
index 35202617716..d8daf79d2d5 100644
--- a/tensorflow/python/keras/layers/serialization.py
+++ b/tensorflow/python/keras/layers/serialization.py
@@ -45,6 +45,15 @@ if tf2.enabled():
   from tensorflow.python.keras.layers.normalization_v2 import *  # pylint: disable=g-import-not-at-top
   from tensorflow.python.keras.layers.recurrent_v2 import *  # pylint: disable=g-import-not-at-top
 
+# This deserialization table is needed for backward compatibility: in TF 1.13,
+# BatchNormalizationV1 and BatchNormalizationV2 were used as the class names
+# for the v1 and v2 versions of BatchNormalization, respectively. Here we
+# explicitly convert them to the canonical name during deserialization.
+_DESERIALIZATION_TABLE = {
+    'BatchNormalizationV1': 'BatchNormalization',
+    'BatchNormalizationV2': 'BatchNormalization',
+}
+
 
 @keras_export('keras.layers.serialize')
 def serialize(layer):
@@ -68,6 +77,9 @@ def deserialize(config, custom_objects=None):
     globs['Network'] = models.Network
     globs['Model'] = models.Model
     globs['Sequential'] = models.Sequential
+    layer_class_name = config['class_name']
+    if layer_class_name in _DESERIALIZATION_TABLE:
+      config['class_name'] = _DESERIALIZATION_TABLE[layer_class_name]
 
     return deserialize_keras_object(
         config,
diff --git a/tensorflow/python/keras/layers/serialization_test.py b/tensorflow/python/keras/layers/serialization_test.py
index 5e9fa3cef8d..c0eb2fe9359 100644
--- a/tensorflow/python/keras/layers/serialization_test.py
+++ b/tensorflow/python/keras/layers/serialization_test.py
@@ -69,6 +69,31 @@ class LayerSerializationTest(parameterized.TestCase, test.TestCase):
     self.assertEqual(new_layer.gamma_regularizer.__class__,
                      keras.regularizers.L1L2)
 
+  @parameterized.parameters(
+      [batchnorm_v1.BatchNormalization, batchnorm_v2.BatchNormalization])
+  def test_deserialize_batchnorm_backward_compatibility(self, batchnorm_layer):
+    layer = batchnorm_layer(
+        momentum=0.9, beta_initializer='zeros', gamma_regularizer='l2')
+    config = keras.layers.serialize(layer)
+    # Simulate a saved model in which the layer class name was recorded as
+    # BatchNormalizationV1 or BatchNormalizationV2.
+    if batchnorm_layer is batchnorm_v1.BatchNormalization:
+      config['class_name'] = 'BatchNormalizationV1'
+    else:
+      config['class_name'] = 'BatchNormalizationV2'
+    new_layer = keras.layers.deserialize(config)
+    self.assertEqual(new_layer.momentum, 0.9)
+    if tf2.enabled():
+      self.assertIsInstance(new_layer, batchnorm_v2.BatchNormalization)
+      self.assertEqual(new_layer.beta_initializer.__class__,
+                       keras.initializers.ZerosV2)
+    else:
+      self.assertIsInstance(new_layer, batchnorm_v1.BatchNormalization)
+      self.assertEqual(new_layer.beta_initializer.__class__,
+                       keras.initializers.Zeros)
+    self.assertEqual(new_layer.gamma_regularizer.__class__,
+                     keras.regularizers.L1L2)
+
   @parameterized.parameters([rnn_v1.LSTM, rnn_v2.LSTM])
   def test_serialize_deserialize_lstm(self, layer):
     lstm = layer(5, return_sequences=True)
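
Note (not part of the patch): a minimal sketch of the behavior this CL
enables, assuming a TensorFlow build that includes the change. The config
below is hypothetical, mimicking the class name a model saved with TF 1.13
would record; the layer name 'bn' is an arbitrary example.

import tensorflow as tf

# Hypothetical config mimicking a model saved with TF 1.13, where the layer
# class was recorded under its versioned name.
config = {
    'class_name': 'BatchNormalizationV2',  # legacy TF 1.13 class name
    'config': {'name': 'bn', 'momentum': 0.9},
}

# With _DESERIALIZATION_TABLE in place, the versioned name is mapped back to
# the canonical 'BatchNormalization' before the class lookup, so this
# deserializes instead of raising an unknown-layer error.
layer = tf.keras.layers.deserialize(config)
print(isinstance(layer, tf.keras.layers.BatchNormalization), layer.momentum)
# -> True 0.9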