To make BatchNormalization layer backward-compatible.
In TF 1.13, BatchNormalization layer is named as BatchNormalizationV1 for v1, and BatchNormalizationV2 for v2 version. So BatchNormalizationV1 and BatchNormalizationV1 appear in saved model. This CL explictily converts v1 and v2 version of names to its canonical name for backward compatibility. PiperOrigin-RevId: 239304519
This commit is contained in:
parent
f67c35d358
commit
18493fb2a3
@ -45,6 +45,15 @@ if tf2.enabled():
|
||||
from tensorflow.python.keras.layers.normalization_v2 import * # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.keras.layers.recurrent_v2 import * # pylint: disable=g-import-not-at-top
|
||||
|
||||
# This deserialization table is added for backward compatibility, as in TF 1.13,
|
||||
# BatchNormalizationV1 and BatchNormalizationV2 are used as class name for v1
|
||||
# and v2 version of BatchNormalization, respectively. Here we explictly convert
|
||||
# them to the canonical name in the config of deserialization.
|
||||
_DESERIALIZATION_TABLE = {
|
||||
'BatchNormalizationV1': 'BatchNormalization',
|
||||
'BatchNormalizationV2': 'BatchNormalization',
|
||||
}
|
||||
|
||||
|
||||
@keras_export('keras.layers.serialize')
|
||||
def serialize(layer):
|
||||
@ -68,6 +77,9 @@ def deserialize(config, custom_objects=None):
|
||||
globs['Network'] = models.Network
|
||||
globs['Model'] = models.Model
|
||||
globs['Sequential'] = models.Sequential
|
||||
layer_class_name = config['class_name']
|
||||
if layer_class_name in _DESERIALIZATION_TABLE:
|
||||
config['class_name'] = _DESERIALIZATION_TABLE[layer_class_name]
|
||||
|
||||
return deserialize_keras_object(
|
||||
config,
|
||||
|
@ -69,6 +69,31 @@ class LayerSerializationTest(parameterized.TestCase, test.TestCase):
|
||||
self.assertEqual(new_layer.gamma_regularizer.__class__,
|
||||
keras.regularizers.L1L2)
|
||||
|
||||
@parameterized.parameters(
|
||||
[batchnorm_v1.BatchNormalization, batchnorm_v2.BatchNormalization])
|
||||
def test_deserialize_batchnorm_backwards_compatiblity(self, batchnorm_layer):
|
||||
layer = batchnorm_layer(
|
||||
momentum=0.9, beta_initializer='zeros', gamma_regularizer='l2')
|
||||
config = keras.layers.serialize(layer)
|
||||
# To simulate if BatchNormalizationV1 or BatchNormalizationV2 appears in the
|
||||
# saved model.
|
||||
if batchnorm_layer is batchnorm_v1.BatchNormalization:
|
||||
config['class_name'] = 'BatchNormalizationV1'
|
||||
else:
|
||||
config['class_name'] = 'BatchNormalizationV2'
|
||||
new_layer = keras.layers.deserialize(config)
|
||||
self.assertEqual(new_layer.momentum, 0.9)
|
||||
if tf2.enabled():
|
||||
self.assertIsInstance(new_layer, batchnorm_v2.BatchNormalization)
|
||||
self.assertEqual(new_layer.beta_initializer.__class__,
|
||||
keras.initializers.ZerosV2)
|
||||
else:
|
||||
self.assertIsInstance(new_layer, batchnorm_v1.BatchNormalization)
|
||||
self.assertEqual(new_layer.beta_initializer.__class__,
|
||||
keras.initializers.Zeros)
|
||||
self.assertEqual(new_layer.gamma_regularizer.__class__,
|
||||
keras.regularizers.L1L2)
|
||||
|
||||
@parameterized.parameters([rnn_v1.LSTM, rnn_v2.LSTM])
|
||||
def test_serialize_deserialize_lstm(self, layer):
|
||||
lstm = layer(5, return_sequences=True)
|
||||
|
Loading…
Reference in New Issue
Block a user