Fix save model issue for stateless ConvLSTM2D layer.
The root cause is that ConvLSTM2D.state is a tuple rather than a list. When converting the state for save_model, the tuple is not converted to trackable objects since the states are (None, None). On the other hand, save_model requires all objects to be trackable when saving. We didn't hit this issue for keras.LSTM since its state is a list, rather than tuple. The list is auto convert to ListWrapper since list itself is mutable. This should fix https://github.com/tensorflow/tensorflow/issues/40328 and partly https://github.com/tensorflow/tensorflow/issues/38220 PiperOrigin-RevId: 317131403 Change-Id: I202d4dbdb29accc7a047d5f5a2fef08d24d05c7c
This commit is contained in:
parent
81041bcd82
commit
47582983cb
|
@ -159,6 +159,12 @@ class RNNSavedModelSaver(LayerSavedModelSaver):
|
||||||
objects, functions = (
|
objects, functions = (
|
||||||
super(RNNSavedModelSaver, self)._get_serialized_attributes_internal(
|
super(RNNSavedModelSaver, self)._get_serialized_attributes_internal(
|
||||||
serialization_cache))
|
serialization_cache))
|
||||||
|
states = data_structures.wrap_or_unwrap(self.obj.states)
|
||||||
objects['states'] = data_structures.wrap_or_unwrap(self.obj.states)
|
# Force the tuple into TupleWrapper which is a trackable object. The
|
||||||
|
# save/load code requires all the objects to be trackable.
|
||||||
|
# Tuple is not converted to TupleWrapper by data_structures.wrap_or_unwrap()
|
||||||
|
# if it doesn't contains any trackable objects.
|
||||||
|
if isinstance(states, tuple):
|
||||||
|
states = data_structures._TupleWrapper(states) # pylint: disable=protected-access
|
||||||
|
objects['states'] = states
|
||||||
return objects, functions
|
return objects, functions
|
||||||
|
|
|
@ -773,6 +773,26 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
|
||||||
self.assertAllClose(layer.states, loaded_layer.states)
|
self.assertAllClose(layer.states, loaded_layer.states)
|
||||||
self.assertAllClose(model(input_arr), loaded(input_arr))
|
self.assertAllClose(model(input_arr), loaded(input_arr))
|
||||||
|
|
||||||
|
def testSaveStatelessConvLSTM2D(self):
|
||||||
|
data_format = 'channels_first'
|
||||||
|
batch, timesteps, channels, rows, cols = 12, 10, 8, 4, 4
|
||||||
|
input_arr = np.ones(
|
||||||
|
(batch, timesteps, channels, rows, cols)).astype('float32')
|
||||||
|
layer = keras.layers.ConvLSTM2D(
|
||||||
|
filters=16, kernel_size=(1, 1), data_format=data_format)
|
||||||
|
x = keras.Input(batch_shape=(batch, timesteps, channels, rows, cols))
|
||||||
|
y = layer(x)
|
||||||
|
model = keras.Model(x, y)
|
||||||
|
|
||||||
|
predict_1 = model(input_arr)
|
||||||
|
saved_model_dir = self._save_model_dir()
|
||||||
|
tf_save.save(model, saved_model_dir)
|
||||||
|
del model
|
||||||
|
|
||||||
|
loaded = keras_load.load(saved_model_dir)
|
||||||
|
predict_2 = loaded(input_arr)
|
||||||
|
self.assertAllClose(predict_1, predict_2)
|
||||||
|
|
||||||
def testSaveWithRaggedInputs(self):
|
def testSaveWithRaggedInputs(self):
|
||||||
|
|
||||||
class EmbeddingMerger(keras.layers.Layer):
|
class EmbeddingMerger(keras.layers.Layer):
|
||||||
|
|
Loading…
Reference in New Issue