Set zero_output_for_mask in tf.keras.layers.Bidirectional only when return_sequences=True.
Previously, Bidirectional would always force zero_output_for_mask=True. However, this made it difficult to get just the last output from the layer when using masked sequences: if return_sequences=False and the final timestep(s) of a sequence were masked, the last output of the forward RNN would be zeros instead of matching its last state.

PiperOrigin-RevId: 236009388
This commit is contained in:
parent
4728deaf57
commit
97e36fc65e
@ -418,8 +418,8 @@ class Bidirectional(Wrapper):
|
||||
'Merge mode should be one of '
|
||||
'{"sum", "mul", "ave", "concat", None}')
|
||||
if getattr(layer, 'zero_output_for_mask', None) is not None:
|
||||
# Force the zero_output_for_mask to be True if it presents.
|
||||
layer.zero_output_for_mask = True
|
||||
# Force the zero_output_for_mask to be True if returning sequences.
|
||||
layer.zero_output_for_mask = layer.return_sequences
|
||||
|
||||
self.forward_layer = copy.copy(layer)
|
||||
config = layer.get_config()
|
||||
|
@ -688,7 +688,33 @@ class BidirectionalTest(test.TestCase):
|
||||
y_np_3 = model.predict([x_np, s_fw_np, s_bk_np, c_np])
|
||||
self.assertAllClose(y_np, y_np_3, atol=1e-4)
|
||||
|
||||
def test_Bidirectional_with_masking(self):
|
||||
def test_Bidirectional_last_output_with_masking(self):
|
||||
rnn = keras.layers.LSTM
|
||||
samples = 2
|
||||
dim = 5
|
||||
timesteps = 3
|
||||
units = 3
|
||||
merge_mode = 'concat'
|
||||
x = np.random.rand(samples, timesteps, dim)
|
||||
# clear the first record's timestep 2. Last output should be same as state,
|
||||
# not zeroed.
|
||||
x[0, 2] = 0
|
||||
|
||||
with self.cached_session():
|
||||
inputs = keras.Input((timesteps, dim))
|
||||
masked_inputs = keras.layers.Masking()(inputs)
|
||||
wrapped = keras.layers.Bidirectional(
|
||||
rnn(units, return_state=True), merge_mode=merge_mode)
|
||||
outputs = _to_list(wrapped(masked_inputs, training=True))
|
||||
self.assertEqual(len(outputs), 5)
|
||||
self.assertEqual(outputs[0].get_shape().as_list(), [None, units * 2])
|
||||
|
||||
model = keras.Model(inputs, outputs)
|
||||
y = _to_list(model.predict(x))
|
||||
self.assertEqual(len(y), 5)
|
||||
self.assertAllClose(y[0], np.concatenate([y[1], y[3]], axis=1))
|
||||
|
||||
def test_Bidirectional_sequence_output_with_masking(self):
|
||||
rnn = keras.layers.LSTM
|
||||
samples = 2
|
||||
dim = 5
|
||||
|
Loading…
Reference in New Issue
Block a user