diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index 66c68e20854..12c82a53f6e 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -670,6 +670,8 @@ class RNN(Layer):
     if generic_utils.has_arg(self.cell.call, 'training'):
       kwargs['training'] = training

+    # TF RNN cells expect a single tensor as state, not a list-wrapped tensor.
+    is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None
     if constants:
       if not generic_utils.has_arg(self.cell.call, 'constants'):
         raise ValueError('RNN cell does not support constants')
@@ -677,11 +679,21 @@ class RNN(Layer):
       def step(inputs, states):
         constants = states[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
         states = states[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type
-        return self.cell.call(inputs, states, constants=constants, **kwargs)
+
+        states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
+        output, new_states = self.cell.call(
+            inputs, states, constants=constants, **kwargs)
+        if not nest.is_sequence(new_states):
+          new_states = [new_states]
+        return output, new_states
     else:

       def step(inputs, states):
-        return self.cell.call(inputs, states, **kwargs)
+        states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
+        output, new_states = self.cell.call(inputs, states, **kwargs)
+        if not nest.is_sequence(new_states):
+          new_states = [new_states]
+        return output, new_states

     last_output, outputs, states = K.rnn(
         step,
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index e32d7c4e672..c72ada11dad 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -301,14 +301,12 @@ class RNNTest(test.TestCase):
     self.assertAllEqual([10.] * 4, self.evaluate(lstm_cell._bias))

   def testRNNCellSerialization(self):
-    for cell in [
+    for cell in [
         rnn_cell_impl.LSTMCell(32, use_peepholes=True, cell_clip=True),
         rnn_cell_impl.BasicLSTMCell(32, dtype=dtypes.float32),
-        # TODO(scottzhu): GRU and BasicRNN cell are not compatible with Keras.
-        # rnn_cell_impl.BasicRNNCell(
-        #     32, activation="relu", dtype=dtypes.float32),
-        # rnn_cell_impl.GRUCell(
-        #     32, kernel_initializer="ones", dtype=dtypes.float32)
+        rnn_cell_impl.BasicRNNCell(32, activation="relu", dtype=dtypes.float32),
+        rnn_cell_impl.GRUCell(
+            32, kernel_initializer="ones", dtype=dtypes.float32)
     ]:
       with self.test_session():
         x = keras.Input((None, 5))
@@ -326,11 +324,13 @@
         # not visible as a Keras layer, and also has a name conflict with
         # keras.LSTMCell and GRUCell.
         layer = keras.layers.RNN.from_config(
-            config, custom_objects={
-                # "BasicRNNCell": rnn_cell_impl.BasicRNNCell,
-                # "GRUCell": rnn_cell_impl.GRUCell,
+            config,
+            custom_objects={
+                "BasicRNNCell": rnn_cell_impl.BasicRNNCell,
+                "GRUCell": rnn_cell_impl.GRUCell,
                 "LSTMCell": rnn_cell_impl.LSTMCell,
-                "BasicLSTMCell": rnn_cell_impl.BasicLSTMCell})
+                "BasicLSTMCell": rnn_cell_impl.BasicLSTMCell
+            })
         y = layer(x)
         model = keras.models.Model(x, y)
         model.set_weights(weights)
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index f481726d545..85a6a2233cc 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -193,6 +193,13 @@ class RNNCell(base_layer.Layer):
       for each `s` in `self.batch_size`.
     """

+  def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
+    super(RNNCell, self).__init__(
+        trainable=trainable, name=name, dtype=dtype, **kwargs)
+    # Attribute that indicates whether the cell is a TF RNN cell, due to the
+    # slight difference between TF and Keras RNN cells.
+    self._is_tf_rnn_cell = True
+
   def __call__(self, inputs, state, scope=None):
     """Run this RNN cell on inputs, starting from the given state.
@@ -524,8 +531,8 @@ class GRUCell(LayerRNNCell):
   def get_config(self):
     config = {
         "num_units": self._num_units,
-        "initializer": initializers.serialize(self._initializer),
         "kernel_initializer": initializers.serialize(self._kernel_initializer),
+        "bias_initializer": initializers.serialize(self._bias_initializer),
         "activation": activations.serialize(self._activation),
         "reuse": self._reuse,
     }