diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py index a7510f0d1f1..01b801330df 100644 --- a/tensorflow/python/keras/layers/wrappers.py +++ b/tensorflow/python/keras/layers/wrappers.py @@ -418,8 +418,8 @@ class Bidirectional(Wrapper): 'Merge mode should be one of ' '{"sum", "mul", "ave", "concat", None}') if getattr(layer, 'zero_output_for_mask', None) is not None: - # Force the zero_output_for_mask to be True if it presents. - layer.zero_output_for_mask = True + # Set zero_output_for_mask to True only when returning sequences. + layer.zero_output_for_mask = layer.return_sequences self.forward_layer = copy.copy(layer) config = layer.get_config() diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py index fec89382063..78f95b31f13 100644 --- a/tensorflow/python/keras/layers/wrappers_test.py +++ b/tensorflow/python/keras/layers/wrappers_test.py @@ -688,7 +688,33 @@ class BidirectionalTest(test.TestCase): y_np_3 = model.predict([x_np, s_fw_np, s_bk_np, c_np]) self.assertAllClose(y_np, y_np_3, atol=1e-4) - def test_Bidirectional_with_masking(self): + def test_Bidirectional_last_output_with_masking(self): + rnn = keras.layers.LSTM + samples = 2 + dim = 5 + timesteps = 3 + units = 3 + merge_mode = 'concat' + x = np.random.rand(samples, timesteps, dim) + # Zero the first sample's final timestep (index 2). The last output should + # equal the final state rather than being zeroed. 
+ x[0, 2] = 0 + + with self.cached_session(): + inputs = keras.Input((timesteps, dim)) + masked_inputs = keras.layers.Masking()(inputs) + wrapped = keras.layers.Bidirectional( + rnn(units, return_state=True), merge_mode=merge_mode) + outputs = _to_list(wrapped(masked_inputs, training=True)) + self.assertEqual(len(outputs), 5) + self.assertEqual(outputs[0].get_shape().as_list(), [None, units * 2]) + + model = keras.Model(inputs, outputs) + y = _to_list(model.predict(x)) + self.assertEqual(len(y), 5) + self.assertAllClose(y[0], np.concatenate([y[1], y[3]], axis=1)) + + def test_Bidirectional_sequence_output_with_masking(self): rnn = keras.layers.LSTM samples = 2 dim = 5