[RNN] Update TF GRU and BasicRNN cell to support Keras RNN layer.
The Keras RNN layer feeds state_t in as a list and expects state_t+1 back as a list, even when the state is a single tensor for a simple cell. TF RNN cells, on the other hand, do not wrap a single-tensor state in a list.

PiperOrigin-RevId: 208666836
parent 4be575a6c2
commit 87e5662309
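To make the mismatch described in the message concrete, here is a minimal plain-Python sketch (not part of the commit; the toy cells and the floats standing in for state tensors are illustrative assumptions):

def keras_style_cell(inputs, states):
  # Keras protocol: the single state arrives and leaves wrapped in a list.
  new_state = states[0] + inputs
  return new_state, [new_state]

def tf_style_cell(inputs, state):
  # TF protocol: a single-tensor state is passed and returned bare.
  new_state = state + inputs
  return new_state, new_state

def call_any_cell(cell, inputs, states, is_tf_rnn_cell):
  # Unwrap a one-element state list before calling a TF cell ...
  states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
  output, new_states = cell(inputs, states)
  # ... and re-wrap whatever comes back so the caller always sees a list.
  if not isinstance(new_states, (list, tuple)):
    new_states = [new_states]
  return output, new_states

print(call_any_cell(keras_style_cell, 1.0, [2.0], is_tf_rnn_cell=False))  # (3.0, [3.0])
print(call_any_cell(tf_style_cell, 1.0, [2.0], is_tf_rnn_cell=True))      # (3.0, [3.0])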
@@ -670,6 +670,8 @@ class RNN(Layer):
     if generic_utils.has_arg(self.cell.call, 'training'):
       kwargs['training'] = training
 
+    # TF RNN cells expect single tensor as state instead of list wrapped tensor.
+    is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None
     if constants:
       if not generic_utils.has_arg(self.cell.call, 'constants'):
         raise ValueError('RNN cell does not support constants')
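The hunk above detects TF cells through a marker attribute rather than an isinstance check, so the Keras layer never has to import the TF cell classes. A small illustration of the idiom (hypothetical class names, not from the commit):

class MarkedCell(object):
  _is_tf_rnn_cell = True

class PlainCell(object):
  pass

# getattr with a None default turns the presence of the marker into a boolean.
print(getattr(MarkedCell(), '_is_tf_rnn_cell', None) is not None)  # True
print(getattr(PlainCell(), '_is_tf_rnn_cell', None) is not None)   # False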
@@ -677,11 +679,21 @@ class RNN(Layer):
       def step(inputs, states):
         constants = states[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
         states = states[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type
-        return self.cell.call(inputs, states, constants=constants, **kwargs)
+
+        states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
+        output, new_states = self.cell.call(
+            inputs, states, constants=constants, **kwargs)
+        if not nest.is_sequence(new_states):
+          new_states = [new_states]
+        return output, new_states
     else:
 
       def step(inputs, states):
-        return self.cell.call(inputs, states, **kwargs)
+        states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
+        output, new_states = self.cell.call(inputs, states, **kwargs)
+        if not nest.is_sequence(new_states):
+          new_states = [new_states]
+        return output, new_states
 
     last_output, outputs, states = K.rnn(
         step,
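A note on the re-wrapping in this hunk: nest.is_sequence is true for tuple states (for example an LSTM's (c, h) pair), so only bare single-tensor states get wrapped. A plain-Python sketch of that behaviour, using isinstance as a stand-in for nest.is_sequence (an assumption for illustration only):

def rewrap(new_states):
  # A bare single state gets wrapped; tuple/list states are already sequences.
  if not isinstance(new_states, (list, tuple)):
    new_states = [new_states]
  return new_states

print(rewrap(3.0))         # GRU/BasicRNN-style single state -> [3.0]
print(rewrap((3.0, 4.0)))  # LSTM-style tuple state -> (3.0, 4.0) unchanged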
@@ -304,11 +304,9 @@ class RNNTest(test.TestCase):
     for cell in [
         rnn_cell_impl.LSTMCell(32, use_peepholes=True, cell_clip=True),
         rnn_cell_impl.BasicLSTMCell(32, dtype=dtypes.float32),
-        # TODO(scottzhu): GRU and BasicRNN cell are not compatible with Keras.
-        # rnn_cell_impl.BasicRNNCell(
-        #     32, activation="relu", dtype=dtypes.float32),
-        # rnn_cell_impl.GRUCell(
-        #     32, kernel_initializer="ones", dtype=dtypes.float32)
+        rnn_cell_impl.BasicRNNCell(32, activation="relu", dtype=dtypes.float32),
+        rnn_cell_impl.GRUCell(
+            32, kernel_initializer="ones", dtype=dtypes.float32)
     ]:
       with self.test_session():
         x = keras.Input((None, 5))
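For readers following along, the pattern the re-enabled test cases exercise is roughly the following (a sketch only, assuming a TF 1.x-era build where tf.nn.rnn_cell exposes these cells and tf.keras.layers.RNN accepts them):

import tensorflow as tf

# Wrap a plain TF cell directly in the Keras RNN layer.
cell = tf.nn.rnn_cell.BasicRNNCell(32, activation=tf.nn.relu)
x = tf.keras.layers.Input(shape=(None, 5))   # (timesteps, features)
y = tf.keras.layers.RNN(cell)(x)
model = tf.keras.models.Model(x, y)
print(model.output_shape)                    # expected: (None, 32)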
@@ -326,11 +324,13 @@ class RNNTest(test.TestCase):
       # not visible as a Keras layer, and also has a name conflict with
       # keras.LSTMCell and GRUCell.
       layer = keras.layers.RNN.from_config(
-          config, custom_objects={
-              # "BasicRNNCell": rnn_cell_impl.BasicRNNCell,
-              # "GRUCell": rnn_cell_impl.GRUCell,
+          config,
+          custom_objects={
+              "BasicRNNCell": rnn_cell_impl.BasicRNNCell,
+              "GRUCell": rnn_cell_impl.GRUCell,
               "LSTMCell": rnn_cell_impl.LSTMCell,
-              "BasicLSTMCell": rnn_cell_impl.BasicLSTMCell})
+              "BasicLSTMCell": rnn_cell_impl.BasicLSTMCell
+          })
       y = layer(x)
       model = keras.models.Model(x, y)
       model.set_weights(weights)
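The custom_objects argument is needed because from_config only sees class names and the TF cells are not registered with Keras' serialization machinery. A hedged sketch of the round trip this test performs (internal rnn_cell_impl import as in the test; TF 1.x build with this change assumed):

import tensorflow as tf
from tensorflow.python.ops import rnn_cell_impl

x = tf.keras.layers.Input(shape=(None, 5))
layer = tf.keras.layers.RNN(rnn_cell_impl.GRUCell(32))
y = layer(x)

config = layer.get_config()
# Rebuilding the layer requires mapping the serialized class name back to the
# TF implementation by hand.
restored = tf.keras.layers.RNN.from_config(
    config, custom_objects={"GRUCell": rnn_cell_impl.GRUCell})
restored_y = restored(x)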
@@ -193,6 +193,13 @@ class RNNCell(base_layer.Layer):
   for each `s` in `self.batch_size`.
   """
 
+  def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
+    super(RNNCell, self).__init__(
+        trainable=trainable, name=name, dtype=dtype, **kwargs)
+    # Attribute that indicates whether the cell is a TF RNN cell, due the slight
+    # difference between TF and Keras RNN cell.
+    self._is_tf_rnn_cell = True
+
   def __call__(self, inputs, state, scope=None):
     """Run this RNN cell on inputs, starting from the given state.
 
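Because the marker is set in the base RNNCell constructor, any cell that chains up to it picks the attribute up automatically. A quick check (assuming a TF 1.x build that includes this change):

import tensorflow as tf

for cell in [tf.nn.rnn_cell.BasicRNNCell(8),
             tf.nn.rnn_cell.GRUCell(8),
             tf.nn.rnn_cell.BasicLSTMCell(8)]:
  # Every TF cell inherits the marker from RNNCell.__init__.
  print(type(cell).__name__, getattr(cell, '_is_tf_rnn_cell', False))

# A Keras-native cell carries no such attribute, so it keeps the list protocol.
print(getattr(tf.keras.layers.GRUCell(8), '_is_tf_rnn_cell', False))  # False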
@@ -524,8 +531,8 @@ class GRUCell(LayerRNNCell):
   def get_config(self):
     config = {
         "num_units": self._num_units,
-        "initializer": initializers.serialize(self._initializer),
         "kernel_initializer": initializers.serialize(self._kernel_initializer),
+        "bias_initializer": initializers.serialize(self._bias_initializer),
         "activation": activations.serialize(self._activation),
         "reuse": self._reuse,
     }
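With this fix, get_config serializes the kernel and bias initializers under their own keys rather than a single "initializer" entry. A small sketch of the resulting config (TF 1.x build with this change assumed):

from tensorflow.python.ops import rnn_cell_impl

config = rnn_cell_impl.GRUCell(32).get_config()
# Both initializers now appear in the serialized config.
print("kernel_initializer" in config, "bias_initializer" in config)  # True True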