Fix the new LSTM/GRU breakage when return_state=True with masking.
See the original issue in #26026. Thanks to the original reporter huan@. PiperOrigin-RevId: 236165775
This commit is contained in:
parent
5139b72eff
commit
d51261962b
@ -2370,7 +2370,7 @@ class UnifiedGRU(DropoutRNNCellMixin, GRU):
|
||||
output = last_output
|
||||
|
||||
if self.return_state:
|
||||
return [output] + states
|
||||
return [output] + list(states)
|
||||
elif self._return_runtime:
|
||||
return output, runtime
|
||||
else:
|
||||
@ -3404,7 +3404,7 @@ class UnifiedLSTM(DropoutRNNCellMixin, LSTM):
|
||||
output = last_output
|
||||
|
||||
if self.return_state:
|
||||
return [output] + states
|
||||
return [output] + list(states)
|
||||
elif self.return_runtime:
|
||||
return output, runtime
|
||||
else:
|
||||
|
@ -335,6 +335,21 @@ class UnifiedGRUTest(keras_parameterized.TestCase):
|
||||
'return_sequences': True},
|
||||
input_shape=(num_samples, timesteps, embedding_dim))
|
||||
|
||||
def test_return_states_GRU(self):
|
||||
layer_class = keras.layers.UnifiedGRU
|
||||
x = np.random.random((2, 3, 4))
|
||||
y = np.abs(np.random.random((2, 5)))
|
||||
s = np.abs(np.random.random((2, 5)))
|
||||
inputs = keras.layers.Input(
|
||||
shape=[3, 4], dtype=dtypes.float32)
|
||||
masked = keras.layers.Masking()(inputs)
|
||||
outputs, states = layer_class(units=5, return_state=True)(masked)
|
||||
|
||||
model = keras.models.Model(inputs, [outputs, states])
|
||||
model.compile(loss='categorical_crossentropy',
|
||||
optimizer=gradient_descent.GradientDescentOptimizer(0.001))
|
||||
model.fit(x, [y, s], epochs=1, batch_size=2, verbose=1)
|
||||
|
||||
def test_dropout_GRU(self):
|
||||
num_samples = 2
|
||||
timesteps = 3
|
||||
|
@ -238,8 +238,9 @@ class UnifiedLSTMTest(keras_parameterized.TestCase):
|
||||
num_samples = 2
|
||||
|
||||
inputs = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim))
|
||||
masked = keras.layers.Masking()(inputs)
|
||||
layer = keras.layers.UnifiedLSTM(units, return_state=True, stateful=True)
|
||||
outputs = layer(inputs)
|
||||
outputs = layer(masked)
|
||||
state = outputs[1:]
|
||||
assert len(state) == num_states
|
||||
model = keras.models.Model(inputs, state[0])
|
||||
|
Loading…
Reference in New Issue
Block a user