Fix save model issue for stateless ConvLSTM2D layer.

The root cause is that ConvLSTM2D.state is a tuple rather than a list. When converting the state for save_model, the tuple is not converted to trackable objects since the states are (None, None). On the other hand, save_model requires all objects to be trackable when saving.

We didn't hit this issue for keras.LSTM since its state is a list, rather than tuple. The list is auto convert to ListWrapper since list itself is mutable.

This should fix https://github.com/tensorflow/tensorflow/issues/40328 and partly https://github.com/tensorflow/tensorflow/issues/38220

PiperOrigin-RevId: 317131403
Change-Id: I202d4dbdb29accc7a047d5f5a2fef08d24d05c7c
This commit is contained in:
Scott Zhu 2020-06-18 10:30:56 -07:00 committed by TensorFlower Gardener
parent 81041bcd82
commit 47582983cb
2 changed files with 28 additions and 2 deletions

View File

@ -159,6 +159,12 @@ class RNNSavedModelSaver(LayerSavedModelSaver):
objects, functions = (
super(RNNSavedModelSaver, self)._get_serialized_attributes_internal(
serialization_cache))
objects['states'] = data_structures.wrap_or_unwrap(self.obj.states)
states = data_structures.wrap_or_unwrap(self.obj.states)
# Force the tuple into TupleWrapper which is a trackable object. The
# save/load code requires all the objects to be trackable.
# Tuple is not converted to TupleWrapper by data_structures.wrap_or_unwrap()
# if it doesn't contains any trackable objects.
if isinstance(states, tuple):
states = data_structures._TupleWrapper(states) # pylint: disable=protected-access
objects['states'] = states
return objects, functions

View File

@ -773,6 +773,26 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
self.assertAllClose(layer.states, loaded_layer.states)
self.assertAllClose(model(input_arr), loaded(input_arr))
def testSaveStatelessConvLSTM2D(self):
data_format = 'channels_first'
batch, timesteps, channels, rows, cols = 12, 10, 8, 4, 4
input_arr = np.ones(
(batch, timesteps, channels, rows, cols)).astype('float32')
layer = keras.layers.ConvLSTM2D(
filters=16, kernel_size=(1, 1), data_format=data_format)
x = keras.Input(batch_shape=(batch, timesteps, channels, rows, cols))
y = layer(x)
model = keras.Model(x, y)
predict_1 = model(input_arr)
saved_model_dir = self._save_model_dir()
tf_save.save(model, saved_model_dir)
del model
loaded = keras_load.load(saved_model_dir)
predict_2 = loaded(input_arr)
self.assertAllClose(predict_1, predict_2)
def testSaveWithRaggedInputs(self):
class EmbeddingMerger(keras.layers.Layer):