Fix save model issue for stateless ConvLSTM2D layer.
The root cause is that ConvLSTM2D.state is a tuple rather than a list. When converting the state for save_model, the tuple is not converted to trackable objects since the states are (None, None). On the other hand, save_model requires all objects to be trackable when saving. We didn't hit this issue for keras.LSTM since its state is a list, rather than tuple. The list is auto convert to ListWrapper since list itself is mutable. This should fix https://github.com/tensorflow/tensorflow/issues/40328 and partly https://github.com/tensorflow/tensorflow/issues/38220 PiperOrigin-RevId: 317131403 Change-Id: I202d4dbdb29accc7a047d5f5a2fef08d24d05c7c
This commit is contained in:
parent
81041bcd82
commit
47582983cb
|
@ -159,6 +159,12 @@ class RNNSavedModelSaver(LayerSavedModelSaver):
|
|||
objects, functions = (
|
||||
super(RNNSavedModelSaver, self)._get_serialized_attributes_internal(
|
||||
serialization_cache))
|
||||
|
||||
objects['states'] = data_structures.wrap_or_unwrap(self.obj.states)
|
||||
states = data_structures.wrap_or_unwrap(self.obj.states)
|
||||
# Force the tuple into TupleWrapper which is a trackable object. The
|
||||
# save/load code requires all the objects to be trackable.
|
||||
# Tuple is not converted to TupleWrapper by data_structures.wrap_or_unwrap()
|
||||
# if it doesn't contains any trackable objects.
|
||||
if isinstance(states, tuple):
|
||||
states = data_structures._TupleWrapper(states) # pylint: disable=protected-access
|
||||
objects['states'] = states
|
||||
return objects, functions
|
||||
|
|
|
@ -773,6 +773,26 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
|
|||
self.assertAllClose(layer.states, loaded_layer.states)
|
||||
self.assertAllClose(model(input_arr), loaded(input_arr))
|
||||
|
||||
def testSaveStatelessConvLSTM2D(self):
|
||||
data_format = 'channels_first'
|
||||
batch, timesteps, channels, rows, cols = 12, 10, 8, 4, 4
|
||||
input_arr = np.ones(
|
||||
(batch, timesteps, channels, rows, cols)).astype('float32')
|
||||
layer = keras.layers.ConvLSTM2D(
|
||||
filters=16, kernel_size=(1, 1), data_format=data_format)
|
||||
x = keras.Input(batch_shape=(batch, timesteps, channels, rows, cols))
|
||||
y = layer(x)
|
||||
model = keras.Model(x, y)
|
||||
|
||||
predict_1 = model(input_arr)
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
del model
|
||||
|
||||
loaded = keras_load.load(saved_model_dir)
|
||||
predict_2 = loaded(input_arr)
|
||||
self.assertAllClose(predict_1, predict_2)
|
||||
|
||||
def testSaveWithRaggedInputs(self):
|
||||
|
||||
class EmbeddingMerger(keras.layers.Layer):
|
||||
|
|
Loading…
Reference in New Issue