Undo changes to the input spec when RNN.unroll is True.
PiperOrigin-RevId: 298492355 Change-Id: I442ad2a23576cae8fa2fad02a27a199f13b159c1
This commit is contained in:
parent
76de54cc83
commit
81e2ecdaae
@ -534,8 +534,7 @@ class RNN(Layer):
|
||||
batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)
|
||||
if not self.stateful:
|
||||
input_spec_shape[batch_index] = None
|
||||
if not getattr(self, 'unroll', False):
|
||||
input_spec_shape[time_step_index] = None
|
||||
input_spec_shape[time_step_index] = None
|
||||
return InputSpec(shape=tuple(input_spec_shape))
|
||||
|
||||
def get_step_input_shape(shape):
|
||||
|
@ -661,8 +661,11 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
|
||||
self.assertAllClose(loaded.predict_on_batch(array_ops.ones((3, 2, 1))),
|
||||
predictions)
|
||||
|
||||
@parameterized.named_parameters([('with_unrolling', True),
|
||||
('no_unrolling', False)])
|
||||
@parameterized.named_parameters([
|
||||
# TODO(b/148491963): Unrolling does not work with SavedModel
|
||||
# ('with_unrolling', True),
|
||||
('no_unrolling', False)
|
||||
])
|
||||
def testSaveStatefulRNN(self, unroll):
|
||||
batch = 12
|
||||
timesteps = 10
|
||||
|
Loading…
Reference in New Issue
Block a user