Make the BatchNormalization layer backward compatible.

In TF 1.13, the BatchNormalization layer is named BatchNormalizationV1 in its v1 version and BatchNormalizationV2 in its v2 version, so both BatchNormalizationV1 and BatchNormalizationV2 appear in saved models. This CL explicitly converts the v1 and v2 names to the canonical name for backward compatibility.

PiperOrigin-RevId: 239304519
Author: Yanhui Liang (2019-03-19 17:25:04 -07:00), committed by TensorFlower Gardener
Commit: 18493fb2a3 (parent: f67c35d358)
2 changed files with 37 additions and 0 deletions
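For illustration (not part of the commit): with this change applied, a config saved by TF 1.13 under one of the old class names deserializes to the canonical layer. A minimal sketch, assuming a TF build that includes this change and the public keras.layers.deserialize API:

    from tensorflow import keras

    # A config as it would appear in a TF 1.13 saved model.
    old_config = {'class_name': 'BatchNormalizationV2',
                  'config': {'momentum': 0.9}}
    layer = keras.layers.deserialize(old_config)
    # Prints 'BatchNormalization'; the v1 or v2 class is chosen per tf2.enabled().
    print(type(layer).__name__)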

tensorflow/python/keras/layers/serialization.py:

@@ -45,6 +45,15 @@ if tf2.enabled():
   from tensorflow.python.keras.layers.normalization_v2 import *  # pylint: disable=g-import-not-at-top
   from tensorflow.python.keras.layers.recurrent_v2 import *  # pylint: disable=g-import-not-at-top
 
+# This deserialization table is added for backward compatibility, as in TF 1.13,
+# BatchNormalizationV1 and BatchNormalizationV2 are used as the class names for
+# the v1 and v2 versions of BatchNormalization, respectively. Here we explicitly
+# convert them to the canonical name in the config during deserialization.
+_DESERIALIZATION_TABLE = {
+    'BatchNormalizationV1': 'BatchNormalization',
+    'BatchNormalizationV2': 'BatchNormalization',
+}
+
 
 @keras_export('keras.layers.serialize')
 def serialize(layer):
@@ -68,6 +77,9 @@ def deserialize(config, custom_objects=None):
   globs['Network'] = models.Network
   globs['Model'] = models.Model
   globs['Sequential'] = models.Sequential
+  layer_class_name = config['class_name']
+  if layer_class_name in _DESERIALIZATION_TABLE:
+    config['class_name'] = _DESERIALIZATION_TABLE[layer_class_name]
   return deserialize_keras_object(
       config,
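The change above amounts to a table lookup on the config's class_name before it reaches deserialize_keras_object. A standalone sketch of that step in plain Python (the _canonicalize helper name is hypothetical, introduced here only for illustration):

    _DESERIALIZATION_TABLE = {
        'BatchNormalizationV1': 'BatchNormalization',
        'BatchNormalizationV2': 'BatchNormalization',
    }

    def _canonicalize(config):
      # Rewrite a legacy TF 1.13 class name to the canonical name, in place;
      # any other class name passes through unchanged.
      name = config['class_name']
      config['class_name'] = _DESERIALIZATION_TABLE.get(name, name)
      return config

    print(_canonicalize({'class_name': 'BatchNormalizationV1', 'config': {}}))
    # {'class_name': 'BatchNormalization', 'config': {}}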

tensorflow/python/keras/layers/serialization_test.py:

@@ -69,6 +69,31 @@ class LayerSerializationTest(parameterized.TestCase, test.TestCase):
     self.assertEqual(new_layer.gamma_regularizer.__class__,
                      keras.regularizers.L1L2)
 
+  @parameterized.parameters(
+      [batchnorm_v1.BatchNormalization, batchnorm_v2.BatchNormalization])
+  def test_deserialize_batchnorm_backwards_compatibility(self, batchnorm_layer):
+    layer = batchnorm_layer(
+        momentum=0.9, beta_initializer='zeros', gamma_regularizer='l2')
+    config = keras.layers.serialize(layer)
+    # Simulate a saved model in which BatchNormalizationV1 or
+    # BatchNormalizationV2 appears as the layer's class name.
+    if batchnorm_layer is batchnorm_v1.BatchNormalization:
+      config['class_name'] = 'BatchNormalizationV1'
+    else:
+      config['class_name'] = 'BatchNormalizationV2'
+    new_layer = keras.layers.deserialize(config)
+    self.assertEqual(new_layer.momentum, 0.9)
+    if tf2.enabled():
+      self.assertIsInstance(new_layer, batchnorm_v2.BatchNormalization)
+      self.assertEqual(new_layer.beta_initializer.__class__,
+                       keras.initializers.ZerosV2)
+    else:
+      self.assertIsInstance(new_layer, batchnorm_v1.BatchNormalization)
+      self.assertEqual(new_layer.beta_initializer.__class__,
+                       keras.initializers.Zeros)
+    self.assertEqual(new_layer.gamma_regularizer.__class__,
+                     keras.regularizers.L1L2)
+
   @parameterized.parameters([rnn_v1.LSTM, rnn_v2.LSTM])
   def test_serialize_deserialize_lstm(self, layer):
     lstm = layer(5, return_sequences=True)