From 47582983cb1064b5bb81233db4f0adeeaa10b74d Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Thu, 18 Jun 2020 10:30:56 -0700 Subject: [PATCH] Fix save model issue for stateless ConvLSTM2D layer. The root cause is that ConvLSTM2D.states is a tuple rather than a list. When converting the state for save_model, the tuple is not converted to trackable objects since the states are (None, None). On the other hand, save_model requires all objects to be trackable when saving. We didn't hit this issue for keras.LSTM since its state is a list, rather than a tuple. The list is automatically converted to ListWrapper since a list is mutable. This should fix https://github.com/tensorflow/tensorflow/issues/40328 and partly https://github.com/tensorflow/tensorflow/issues/38220 PiperOrigin-RevId: 317131403 Change-Id: I202d4dbdb29accc7a047d5f5a2fef08d24d05c7c --- .../saving/saved_model/layer_serialization.py | 10 ++++++++-- .../saving/saved_model/saved_model_test.py | 20 +++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/saving/saved_model/layer_serialization.py b/tensorflow/python/keras/saving/saved_model/layer_serialization.py index 559b6158d87..4216457bf28 100644 --- a/tensorflow/python/keras/saving/saved_model/layer_serialization.py +++ b/tensorflow/python/keras/saving/saved_model/layer_serialization.py @@ -159,6 +159,12 @@ class RNNSavedModelSaver(LayerSavedModelSaver): objects, functions = ( super(RNNSavedModelSaver, self)._get_serialized_attributes_internal( serialization_cache)) - - objects['states'] = data_structures.wrap_or_unwrap(self.obj.states) + states = data_structures.wrap_or_unwrap(self.obj.states) + # Force the tuple into TupleWrapper, which is a trackable object. The + # save/load code requires all the objects to be trackable. + # Tuple is not converted to TupleWrapper by data_structures.wrap_or_unwrap() + # if it doesn't contain any trackable objects. 
+ if isinstance(states, tuple): + states = data_structures._TupleWrapper(states) # pylint: disable=protected-access + objects['states'] = states return objects, functions diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py index 8d4d27e2357..3f55d5f40b5 100644 --- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py +++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py @@ -773,6 +773,26 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase): self.assertAllClose(layer.states, loaded_layer.states) self.assertAllClose(model(input_arr), loaded(input_arr)) + def testSaveStatelessConvLSTM2D(self): + data_format = 'channels_first' + batch, timesteps, channels, rows, cols = 12, 10, 8, 4, 4 + input_arr = np.ones( + (batch, timesteps, channels, rows, cols)).astype('float32') + layer = keras.layers.ConvLSTM2D( + filters=16, kernel_size=(1, 1), data_format=data_format) + x = keras.Input(batch_shape=(batch, timesteps, channels, rows, cols)) + y = layer(x) + model = keras.Model(x, y) + + predict_1 = model(input_arr) + saved_model_dir = self._save_model_dir() + tf_save.save(model, saved_model_dir) + del model + + loaded = keras_load.load(saved_model_dir) + predict_2 = loaded(input_arr) + self.assertAllClose(predict_1, predict_2) + def testSaveWithRaggedInputs(self): class EmbeddingMerger(keras.layers.Layer):