Undo changes to the input spec when RNN.unroll is True.
PiperOrigin-RevId: 298492355 Change-Id: I442ad2a23576cae8fa2fad02a27a199f13b159c1
This commit is contained in:
parent
76de54cc83
commit
81e2ecdaae
@ -534,8 +534,7 @@ class RNN(Layer):
|
|||||||
batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)
|
batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)
|
||||||
if not self.stateful:
|
if not self.stateful:
|
||||||
input_spec_shape[batch_index] = None
|
input_spec_shape[batch_index] = None
|
||||||
if not getattr(self, 'unroll', False):
|
input_spec_shape[time_step_index] = None
|
||||||
input_spec_shape[time_step_index] = None
|
|
||||||
return InputSpec(shape=tuple(input_spec_shape))
|
return InputSpec(shape=tuple(input_spec_shape))
|
||||||
|
|
||||||
def get_step_input_shape(shape):
|
def get_step_input_shape(shape):
|
||||||
|
@ -661,8 +661,11 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
|
|||||||
self.assertAllClose(loaded.predict_on_batch(array_ops.ones((3, 2, 1))),
|
self.assertAllClose(loaded.predict_on_batch(array_ops.ones((3, 2, 1))),
|
||||||
predictions)
|
predictions)
|
||||||
|
|
||||||
@parameterized.named_parameters([('with_unrolling', True),
|
@parameterized.named_parameters([
|
||||||
('no_unrolling', False)])
|
# TODO(b/148491963): Unrolling does not work with SavedModel
|
||||||
|
# ('with_unrolling', True),
|
||||||
|
('no_unrolling', False)
|
||||||
|
])
|
||||||
def testSaveStatefulRNN(self, unroll):
|
def testSaveStatefulRNN(self, unroll):
|
||||||
batch = 12
|
batch = 12
|
||||||
timesteps = 10
|
timesteps = 10
|
||||||
|
Loading…
x
Reference in New Issue
Block a user