Undo changes to the input spec when RNN.unroll is True.

PiperOrigin-RevId: 298492355
Change-Id: I442ad2a23576cae8fa2fad02a27a199f13b159c1
This commit is contained in:
Katherine Wu 2020-03-02 17:53:53 -08:00 committed by TensorFlower Gardener
parent 76de54cc83
commit 81e2ecdaae
2 changed files with 6 additions and 4 deletions

View File

@ -534,8 +534,7 @@ class RNN(Layer):
batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)
if not self.stateful:
input_spec_shape[batch_index] = None
if not getattr(self, 'unroll', False):
input_spec_shape[time_step_index] = None
input_spec_shape[time_step_index] = None
return InputSpec(shape=tuple(input_spec_shape))
def get_step_input_shape(shape):

View File

@ -661,8 +661,11 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
self.assertAllClose(loaded.predict_on_batch(array_ops.ones((3, 2, 1))),
predictions)
@parameterized.named_parameters([('with_unrolling', True),
('no_unrolling', False)])
@parameterized.named_parameters([
# TODO(b/148491963): Unrolling does not work with SavedModel
# ('with_unrolling', True),
('no_unrolling', False)
])
def testSaveStatefulRNN(self, unroll):
batch = 12
timesteps = 10