diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index a90415220a2..11ac4494af0 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -2370,7 +2370,7 @@ class UnifiedGRU(DropoutRNNCellMixin, GRU):
       output = last_output
 
     if self.return_state:
-      return [output] + states
+      return [output] + list(states)
     elif self._return_runtime:
       return output, runtime
     else:
@@ -3404,7 +3404,7 @@ class UnifiedLSTM(DropoutRNNCellMixin, LSTM):
       output = last_output
 
     if self.return_state:
-      return [output] + states
+      return [output] + list(states)
     elif self.return_runtime:
       return output, runtime
     else:
diff --git a/tensorflow/python/keras/layers/unified_gru_test.py b/tensorflow/python/keras/layers/unified_gru_test.py
index db861042380..3015b8bbfb0 100644
--- a/tensorflow/python/keras/layers/unified_gru_test.py
+++ b/tensorflow/python/keras/layers/unified_gru_test.py
@@ -335,6 +335,21 @@ class UnifiedGRUTest(keras_parameterized.TestCase):
                 'return_sequences': True},
         input_shape=(num_samples, timesteps, embedding_dim))
 
+  def test_return_states_GRU(self):
+    layer_class = keras.layers.UnifiedGRU
+    x = np.random.random((2, 3, 4))
+    y = np.abs(np.random.random((2, 5)))
+    s = np.abs(np.random.random((2, 5)))
+    inputs = keras.layers.Input(
+        shape=[3, 4], dtype=dtypes.float32)
+    masked = keras.layers.Masking()(inputs)
+    outputs, states = layer_class(units=5, return_state=True)(masked)
+
+    model = keras.models.Model(inputs, [outputs, states])
+    model.compile(loss='categorical_crossentropy',
+                  optimizer=gradient_descent.GradientDescentOptimizer(0.001))
+    model.fit(x, [y, s], epochs=1, batch_size=2, verbose=1)
+
   def test_dropout_GRU(self):
     num_samples = 2
     timesteps = 3
diff --git a/tensorflow/python/keras/layers/unified_lstm_test.py b/tensorflow/python/keras/layers/unified_lstm_test.py
index 01089e1165f..316ce74d801 100644
--- a/tensorflow/python/keras/layers/unified_lstm_test.py
+++ b/tensorflow/python/keras/layers/unified_lstm_test.py
@@ -238,8 +238,9 @@ class UnifiedLSTMTest(keras_parameterized.TestCase):
     num_samples = 2
 
     inputs = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim))
+    masked = keras.layers.Masking()(inputs)
     layer = keras.layers.UnifiedLSTM(units, return_state=True, stateful=True)
-    outputs = layer(inputs)
+    outputs = layer(masked)
     state = outputs[1:]
     assert len(state) == num_states
     model = keras.models.Model(inputs, state[0])
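
Note on the recurrent.py change: with `return_state=True`, `[output] + states` raises a TypeError whenever `states` arrives as a tuple, because Python refuses to concatenate a list with a tuple. Wrapping the value in `list(states)` works for either container and still yields the flat `[output, *states]` list that the list-returning path already produced. The snippet below is a minimal standalone sketch of that failure mode, using plain Python objects as stand-ins for tensors; it is not the TensorFlow implementation itself.

    # Stand-ins for the layer output and the returned states; the real code
    # deals in tensors, but the container behaviour is the same.
    output = 'output'
    states = ('state_h',)               # states may arrive as a tuple

    try:
        _ = [output] + states           # list + tuple -> TypeError
    except TypeError as err:
        print('old behaviour:', err)

    result = [output] + list(states)    # the fix: coerce to a list first
    print('new behaviour:', result)     # ['output', 'state_h']

The new GRU test and the Masking layer added to the LSTM test exercise exactly this path, since they drive a model that consumes the returned states directly.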