[RNN] Update TF GRU and BasicRNN cell to support Keras RNN layer.

The Keras RNN layer will feed in state_t as a list, and expects state_t+1 to be a list as well, even when the state is a single tensor for a simple cell. On the other hand, TF RNN does not wrap a single-tensor state in a list.

PiperOrigin-RevId: 208666836
This commit is contained in:
Scott Zhu 2018-08-14 10:13:53 -07:00 committed by TensorFlower Gardener
parent 4be575a6c2
commit 87e5662309
3 changed files with 32 additions and 13 deletions

View File

@ -670,6 +670,8 @@ class RNN(Layer):
if generic_utils.has_arg(self.cell.call, 'training'):
kwargs['training'] = training
# TF RNN cells expect single tensor as state instead of list wrapped tensor.
is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None
if constants:
if not generic_utils.has_arg(self.cell.call, 'constants'):
raise ValueError('RNN cell does not support constants')
@ -677,11 +679,21 @@ class RNN(Layer):
def step(inputs, states):
constants = states[-self._num_constants:] # pylint: disable=invalid-unary-operand-type
states = states[:-self._num_constants] # pylint: disable=invalid-unary-operand-type
return self.cell.call(inputs, states, constants=constants, **kwargs)
states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
output, new_states = self.cell.call(
inputs, states, constants=constants, **kwargs)
if not nest.is_sequence(new_states):
new_states = [new_states]
return output, new_states
else:
def step(inputs, states):
return self.cell.call(inputs, states, **kwargs)
states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
output, new_states = self.cell.call(inputs, states, **kwargs)
if not nest.is_sequence(new_states):
new_states = [new_states]
return output, new_states
last_output, outputs, states = K.rnn(
step,

View File

@ -301,14 +301,12 @@ class RNNTest(test.TestCase):
self.assertAllEqual([10.] * 4, self.evaluate(lstm_cell._bias))
def testRNNCellSerialization(self):
for cell in [
for cell in [
rnn_cell_impl.LSTMCell(32, use_peepholes=True, cell_clip=True),
rnn_cell_impl.BasicLSTMCell(32, dtype=dtypes.float32),
# TODO(scottzhu): GRU and BasicRNN cell are not compatible with Keras.
# rnn_cell_impl.BasicRNNCell(
# 32, activation="relu", dtype=dtypes.float32),
# rnn_cell_impl.GRUCell(
# 32, kernel_initializer="ones", dtype=dtypes.float32)
rnn_cell_impl.BasicRNNCell(32, activation="relu", dtype=dtypes.float32),
rnn_cell_impl.GRUCell(
32, kernel_initializer="ones", dtype=dtypes.float32)
]:
with self.test_session():
x = keras.Input((None, 5))
@ -326,11 +324,13 @@ class RNNTest(test.TestCase):
# not visible as a Keras layer, and also has a name conflict with
# keras.LSTMCell and GRUCell.
layer = keras.layers.RNN.from_config(
config, custom_objects={
# "BasicRNNCell": rnn_cell_impl.BasicRNNCell,
# "GRUCell": rnn_cell_impl.GRUCell,
config,
custom_objects={
"BasicRNNCell": rnn_cell_impl.BasicRNNCell,
"GRUCell": rnn_cell_impl.GRUCell,
"LSTMCell": rnn_cell_impl.LSTMCell,
"BasicLSTMCell": rnn_cell_impl.BasicLSTMCell})
"BasicLSTMCell": rnn_cell_impl.BasicLSTMCell
})
y = layer(x)
model = keras.models.Model(x, y)
model.set_weights(weights)

View File

@ -193,6 +193,13 @@ class RNNCell(base_layer.Layer):
for each `s` in `self.batch_size`.
"""
def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
super(RNNCell, self).__init__(
trainable=trainable, name=name, dtype=dtype, **kwargs)
# Attribute that indicates whether the cell is a TF RNN cell, due to the
# slight difference between TF and Keras RNN cell.
self._is_tf_rnn_cell = True
def __call__(self, inputs, state, scope=None):
"""Run this RNN cell on inputs, starting from the given state.
@ -524,8 +531,8 @@ class GRUCell(LayerRNNCell):
def get_config(self):
config = {
"num_units": self._num_units,
"initializer": initializers.serialize(self._initializer),
"kernel_initializer": initializers.serialize(self._kernel_initializer),
"bias_initializer": initializers.serialize(self._bias_initializer),
"activation": activations.serialize(self._activation),
"reuse": self._reuse,
}