diff --git a/tensorflow/python/keras/saving/saved_model/layer_serialization.py b/tensorflow/python/keras/saving/saved_model/layer_serialization.py index 559b6158d87..4216457bf28 100644 --- a/tensorflow/python/keras/saving/saved_model/layer_serialization.py +++ b/tensorflow/python/keras/saving/saved_model/layer_serialization.py @@ -159,6 +159,12 @@ class RNNSavedModelSaver(LayerSavedModelSaver): objects, functions = ( super(RNNSavedModelSaver, self)._get_serialized_attributes_internal( serialization_cache)) - - objects['states'] = data_structures.wrap_or_unwrap(self.obj.states) + states = data_structures.wrap_or_unwrap(self.obj.states) + # Force the tuple into TupleWrapper which is a trackable object. The + # save/load code requires all the objects to be trackable. + # Tuple is not converted to TupleWrapper by data_structures.wrap_or_unwrap() + # if it doesn't contains any trackable objects. + if isinstance(states, tuple): + states = data_structures._TupleWrapper(states) # pylint: disable=protected-access + objects['states'] = states return objects, functions diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py index 8d4d27e2357..3f55d5f40b5 100644 --- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py +++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py @@ -773,6 +773,26 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase): self.assertAllClose(layer.states, loaded_layer.states) self.assertAllClose(model(input_arr), loaded(input_arr)) + def testSaveStatelessConvLSTM2D(self): + data_format = 'channels_first' + batch, timesteps, channels, rows, cols = 12, 10, 8, 4, 4 + input_arr = np.ones( + (batch, timesteps, channels, rows, cols)).astype('float32') + layer = keras.layers.ConvLSTM2D( + filters=16, kernel_size=(1, 1), data_format=data_format) + x = keras.Input(batch_shape=(batch, timesteps, channels, rows, cols)) + y = layer(x) + model = keras.Model(x, y) + + predict_1 = model(input_arr) + saved_model_dir = self._save_model_dir() + tf_save.save(model, saved_model_dir) + del model + + loaded = keras_load.load(saved_model_dir) + predict_2 = loaded(input_arr) + self.assertAllClose(predict_1, predict_2) + def testSaveWithRaggedInputs(self): class EmbeddingMerger(keras.layers.Layer):