[RNN] Update TF GRU and BasicRNN cell to support Keras RNN layer.

The Keras RNN layer will feed in state_t as a list, and expects state_t+1 as a list as well, even when the state is a single tensor for a simple cell. On the other hand, TF RNN cells do not wrap a single-tensor state in a list.

PiperOrigin-RevId: 208666836
This commit is contained in:
Scott Zhu 2018-08-14 10:13:53 -07:00 committed by TensorFlower Gardener
parent 4be575a6c2
commit 87e5662309
3 changed files with 32 additions and 13 deletions

View File

@ -670,6 +670,8 @@ class RNN(Layer):
if generic_utils.has_arg(self.cell.call, 'training'): if generic_utils.has_arg(self.cell.call, 'training'):
kwargs['training'] = training kwargs['training'] = training
# TF RNN cells expect single tensor as state instead of list wrapped tensor.
is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None
if constants: if constants:
if not generic_utils.has_arg(self.cell.call, 'constants'): if not generic_utils.has_arg(self.cell.call, 'constants'):
raise ValueError('RNN cell does not support constants') raise ValueError('RNN cell does not support constants')
@ -677,11 +679,21 @@ class RNN(Layer):
def step(inputs, states): def step(inputs, states):
constants = states[-self._num_constants:] # pylint: disable=invalid-unary-operand-type constants = states[-self._num_constants:] # pylint: disable=invalid-unary-operand-type
states = states[:-self._num_constants] # pylint: disable=invalid-unary-operand-type states = states[:-self._num_constants] # pylint: disable=invalid-unary-operand-type
return self.cell.call(inputs, states, constants=constants, **kwargs)
states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
output, new_states = self.cell.call(
inputs, states, constants=constants, **kwargs)
if not nest.is_sequence(new_states):
new_states = [new_states]
return output, new_states
else: else:
def step(inputs, states): def step(inputs, states):
return self.cell.call(inputs, states, **kwargs) states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
output, new_states = self.cell.call(inputs, states, **kwargs)
if not nest.is_sequence(new_states):
new_states = [new_states]
return output, new_states
last_output, outputs, states = K.rnn( last_output, outputs, states = K.rnn(
step, step,

View File

@ -304,11 +304,9 @@ class RNNTest(test.TestCase):
for cell in [ for cell in [
rnn_cell_impl.LSTMCell(32, use_peepholes=True, cell_clip=True), rnn_cell_impl.LSTMCell(32, use_peepholes=True, cell_clip=True),
rnn_cell_impl.BasicLSTMCell(32, dtype=dtypes.float32), rnn_cell_impl.BasicLSTMCell(32, dtype=dtypes.float32),
# TODO(scottzhu): GRU and BasicRNN cell are not compatible with Keras. rnn_cell_impl.BasicRNNCell(32, activation="relu", dtype=dtypes.float32),
# rnn_cell_impl.BasicRNNCell( rnn_cell_impl.GRUCell(
# 32, activation="relu", dtype=dtypes.float32), 32, kernel_initializer="ones", dtype=dtypes.float32)
# rnn_cell_impl.GRUCell(
# 32, kernel_initializer="ones", dtype=dtypes.float32)
]: ]:
with self.test_session(): with self.test_session():
x = keras.Input((None, 5)) x = keras.Input((None, 5))
@ -326,11 +324,13 @@ class RNNTest(test.TestCase):
# not visible as a Keras layer, and also has a name conflict with # not visible as a Keras layer, and also has a name conflict with
# keras.LSTMCell and GRUCell. # keras.LSTMCell and GRUCell.
layer = keras.layers.RNN.from_config( layer = keras.layers.RNN.from_config(
config, custom_objects={ config,
# "BasicRNNCell": rnn_cell_impl.BasicRNNCell, custom_objects={
# "GRUCell": rnn_cell_impl.GRUCell, "BasicRNNCell": rnn_cell_impl.BasicRNNCell,
"GRUCell": rnn_cell_impl.GRUCell,
"LSTMCell": rnn_cell_impl.LSTMCell, "LSTMCell": rnn_cell_impl.LSTMCell,
"BasicLSTMCell": rnn_cell_impl.BasicLSTMCell}) "BasicLSTMCell": rnn_cell_impl.BasicLSTMCell
})
y = layer(x) y = layer(x)
model = keras.models.Model(x, y) model = keras.models.Model(x, y)
model.set_weights(weights) model.set_weights(weights)

View File

@ -193,6 +193,13 @@ class RNNCell(base_layer.Layer):
for each `s` in `self.batch_size`. for each `s` in `self.batch_size`.
""" """
def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
super(RNNCell, self).__init__(
trainable=trainable, name=name, dtype=dtype, **kwargs)
# Attribute that indicates whether the cell is a TF RNN cell, due to the slight
# difference between TF and Keras RNN cells.
self._is_tf_rnn_cell = True
def __call__(self, inputs, state, scope=None): def __call__(self, inputs, state, scope=None):
"""Run this RNN cell on inputs, starting from the given state. """Run this RNN cell on inputs, starting from the given state.
@ -524,8 +531,8 @@ class GRUCell(LayerRNNCell):
def get_config(self): def get_config(self):
config = { config = {
"num_units": self._num_units, "num_units": self._num_units,
"initializer": initializers.serialize(self._initializer),
"kernel_initializer": initializers.serialize(self._kernel_initializer), "kernel_initializer": initializers.serialize(self._kernel_initializer),
"bias_initializer": initializers.serialize(self._bias_initializer),
"activation": activations.serialize(self._activation), "activation": activations.serialize(self._activation),
"reuse": self._reuse, "reuse": self._reuse,
} }