Set zero_output_for_mask in tf.keras.layers.Bidirectional only when return_sequences=True.

Previously, Bidirectional would always force zero_output_for_mask=True. However, this made it difficult to get just the last output from the layer when using masked sequences - if return_sequences=False, the last output for the forward RNN would be zeros.

PiperOrigin-RevId: 236009388
This commit is contained in:
A. Unique TensorFlower 2019-02-27 15:53:04 -08:00 committed by TensorFlower Gardener
parent 4728deaf57
commit 97e36fc65e
2 changed files with 29 additions and 3 deletions

View File

@ -418,8 +418,8 @@ class Bidirectional(Wrapper):
'Merge mode should be one of '
'{"sum", "mul", "ave", "concat", None}')
if getattr(layer, 'zero_output_for_mask', None) is not None:
# Force the zero_output_for_mask to be True if it presents.
layer.zero_output_for_mask = True
# Force the zero_output_for_mask to be True if returning sequences.
layer.zero_output_for_mask = layer.return_sequences
self.forward_layer = copy.copy(layer)
config = layer.get_config()

View File

@ -688,7 +688,33 @@ class BidirectionalTest(test.TestCase):
y_np_3 = model.predict([x_np, s_fw_np, s_bk_np, c_np])
self.assertAllClose(y_np, y_np_3, atol=1e-4)
def test_Bidirectional_with_masking(self):
def test_Bidirectional_last_output_with_masking(self):
rnn = keras.layers.LSTM
samples = 2
dim = 5
timesteps = 3
units = 3
merge_mode = 'concat'
x = np.random.rand(samples, timesteps, dim)
# clear the first record's timestep 2. Last output should be same as state,
# not zeroed.
x[0, 2] = 0
with self.cached_session():
inputs = keras.Input((timesteps, dim))
masked_inputs = keras.layers.Masking()(inputs)
wrapped = keras.layers.Bidirectional(
rnn(units, return_state=True), merge_mode=merge_mode)
outputs = _to_list(wrapped(masked_inputs, training=True))
self.assertEqual(len(outputs), 5)
self.assertEqual(outputs[0].get_shape().as_list(), [None, units * 2])
model = keras.Model(inputs, outputs)
y = _to_list(model.predict(x))
self.assertEqual(len(y), 5)
self.assertAllClose(y[0], np.concatenate([y[1], y[3]], axis=1))
def test_Bidirectional_sequence_output_with_masking(self):
rnn = keras.layers.LSTM
samples = 2
dim = 5