[RNN] Update TF GRU and BasicRNN cell to support Keras RNN layer.
Keras RNN layer will feed in state_t as a list, and expect state_t+1 as list as well, even the state is a single tensor for simple cell. On the other hand, TF RNN does not wrap single tensor state. PiperOrigin-RevId: 208666836
This commit is contained in:
parent
4be575a6c2
commit
87e5662309
@ -670,6 +670,8 @@ class RNN(Layer):
|
||||
if generic_utils.has_arg(self.cell.call, 'training'):
|
||||
kwargs['training'] = training
|
||||
|
||||
# TF RNN cells expect single tensor as state instead of list wrapped tensor.
|
||||
is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None
|
||||
if constants:
|
||||
if not generic_utils.has_arg(self.cell.call, 'constants'):
|
||||
raise ValueError('RNN cell does not support constants')
|
||||
@ -677,11 +679,21 @@ class RNN(Layer):
|
||||
def step(inputs, states):
|
||||
constants = states[-self._num_constants:] # pylint: disable=invalid-unary-operand-type
|
||||
states = states[:-self._num_constants] # pylint: disable=invalid-unary-operand-type
|
||||
return self.cell.call(inputs, states, constants=constants, **kwargs)
|
||||
|
||||
states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
|
||||
output, new_states = self.cell.call(
|
||||
inputs, states, constants=constants, **kwargs)
|
||||
if not nest.is_sequence(new_states):
|
||||
new_states = [new_states]
|
||||
return output, new_states
|
||||
else:
|
||||
|
||||
def step(inputs, states):
|
||||
return self.cell.call(inputs, states, **kwargs)
|
||||
states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
|
||||
output, new_states = self.cell.call(inputs, states, **kwargs)
|
||||
if not nest.is_sequence(new_states):
|
||||
new_states = [new_states]
|
||||
return output, new_states
|
||||
|
||||
last_output, outputs, states = K.rnn(
|
||||
step,
|
||||
|
@ -301,14 +301,12 @@ class RNNTest(test.TestCase):
|
||||
self.assertAllEqual([10.] * 4, self.evaluate(lstm_cell._bias))
|
||||
|
||||
def testRNNCellSerialization(self):
|
||||
for cell in [
|
||||
for cell in [
|
||||
rnn_cell_impl.LSTMCell(32, use_peepholes=True, cell_clip=True),
|
||||
rnn_cell_impl.BasicLSTMCell(32, dtype=dtypes.float32),
|
||||
# TODO(scottzhu): GRU and BasicRNN cell are not compatible with Keras.
|
||||
# rnn_cell_impl.BasicRNNCell(
|
||||
# 32, activation="relu", dtype=dtypes.float32),
|
||||
# rnn_cell_impl.GRUCell(
|
||||
# 32, kernel_initializer="ones", dtype=dtypes.float32)
|
||||
rnn_cell_impl.BasicRNNCell(32, activation="relu", dtype=dtypes.float32),
|
||||
rnn_cell_impl.GRUCell(
|
||||
32, kernel_initializer="ones", dtype=dtypes.float32)
|
||||
]:
|
||||
with self.test_session():
|
||||
x = keras.Input((None, 5))
|
||||
@ -326,11 +324,13 @@ class RNNTest(test.TestCase):
|
||||
# not visible as a Keras layer, and also has a name conflict with
|
||||
# keras.LSTMCell and GRUCell.
|
||||
layer = keras.layers.RNN.from_config(
|
||||
config, custom_objects={
|
||||
# "BasicRNNCell": rnn_cell_impl.BasicRNNCell,
|
||||
# "GRUCell": rnn_cell_impl.GRUCell,
|
||||
config,
|
||||
custom_objects={
|
||||
"BasicRNNCell": rnn_cell_impl.BasicRNNCell,
|
||||
"GRUCell": rnn_cell_impl.GRUCell,
|
||||
"LSTMCell": rnn_cell_impl.LSTMCell,
|
||||
"BasicLSTMCell": rnn_cell_impl.BasicLSTMCell})
|
||||
"BasicLSTMCell": rnn_cell_impl.BasicLSTMCell
|
||||
})
|
||||
y = layer(x)
|
||||
model = keras.models.Model(x, y)
|
||||
model.set_weights(weights)
|
||||
|
@ -193,6 +193,13 @@ class RNNCell(base_layer.Layer):
|
||||
for each `s` in `self.batch_size`.
|
||||
"""
|
||||
|
||||
def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
|
||||
super(RNNCell, self).__init__(
|
||||
trainable=trainable, name=name, dtype=dtype, **kwargs)
|
||||
# Attribute that indicates whether the cell is a TF RNN cell, due the slight
|
||||
# difference between TF and Keras RNN cell.
|
||||
self._is_tf_rnn_cell = True
|
||||
|
||||
def __call__(self, inputs, state, scope=None):
|
||||
"""Run this RNN cell on inputs, starting from the given state.
|
||||
|
||||
@ -524,8 +531,8 @@ class GRUCell(LayerRNNCell):
|
||||
def get_config(self):
|
||||
config = {
|
||||
"num_units": self._num_units,
|
||||
"initializer": initializers.serialize(self._initializer),
|
||||
"kernel_initializer": initializers.serialize(self._kernel_initializer),
|
||||
"bias_initializer": initializers.serialize(self._bias_initializer),
|
||||
"activation": activations.serialize(self._activation),
|
||||
"reuse": self._reuse,
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user