Breaking change in TF RNN Python API: return the final state instead of the
list of states when calling tf.nn.rnn() and tf.nn.state_saving_rnn()

This is necessary for further cleanup of the RNN state propagation code.
(Currently, dynamic RNN calculations do not return the proper final state
when sequence_length is passed; this change is a prerequisite for fixing that
efficiently.)
Change: 113203893
This commit is contained in:
Eugene Brevdo 2016-01-27 14:54:54 -08:00 committed by Vijay Vasudevan
parent e59493941c
commit fea55e1e05
6 changed files with 113 additions and 139 deletions

View File

@ -31,6 +31,9 @@
* ASSERT_OK / EXPECT_OK macros conflicted with external projects, so they were
renamed TF_ASSERT_OK, TF_EXPECT_OK. The existing macros are currently
maintained for short-term compatibility but will be removed.
* The non-public `nn.rnn` and the various `nn.seq2seq` methods now return
just the final state instead of the list of all states.
## Bug fixes

View File

@ -117,16 +117,14 @@ class PTBModel(object):
# from tensorflow.models.rnn import rnn
# inputs = [tf.squeeze(input_, [1])
# for input_ in tf.split(1, num_steps, inputs)]
# outputs, states = rnn.rnn(cell, inputs, initial_state=self._initial_state)
# outputs, state = rnn.rnn(cell, inputs, initial_state=self._initial_state)
outputs = []
states = []
state = self._initial_state
with tf.variable_scope("RNN"):
for time_step in range(num_steps):
if time_step > 0: tf.get_variable_scope().reuse_variables()
(cell_output, state) = cell(inputs[:, time_step, :], state)
outputs.append(cell_output)
states.append(state)
output = tf.reshape(tf.concat(1, outputs), [-1, size])
softmax_w = tf.get_variable("softmax_w", [size, vocab_size])
@ -137,7 +135,7 @@ class PTBModel(object):
[tf.ones([batch_size * num_steps])],
vocab_size)
self._cost = cost = tf.reduce_sum(loss) / batch_size
self._final_state = states[-1]
self._final_state = state
if not is_training:
return

View File

@ -68,7 +68,7 @@ class RNNTest(tf.test.TestCase):
max_length = 8 # unrolled up to this length
inputs = max_length * [
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
outputs, states = tf.nn.rnn(cell, inputs, dtype=tf.float32)
outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
self.assertEqual(len(outputs), len(inputs))
for out, inp in zip(outputs, inputs):
self.assertEqual(out.get_shape(), inp.get_shape())
@ -76,7 +76,7 @@ class RNNTest(tf.test.TestCase):
with self.test_session(use_gpu=False) as sess:
input_value = np.random.randn(batch_size, input_size)
values = sess.run(outputs + [states[-1]],
values = sess.run(outputs + [state],
feed_dict={inputs[0]: input_value})
# Outputs
@ -98,7 +98,7 @@ class RNNTest(tf.test.TestCase):
inputs = max_length * [
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
with tf.variable_scope("share_scope"):
outputs, states = tf.nn.rnn(cell, inputs, dtype=tf.float32)
outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
with tf.variable_scope("drop_scope"):
dropped_outputs, _ = tf.nn.rnn(
full_dropout_cell, inputs, dtype=tf.float32)
@ -109,7 +109,7 @@ class RNNTest(tf.test.TestCase):
with self.test_session(use_gpu=False) as sess:
input_value = np.random.randn(batch_size, input_size)
values = sess.run(outputs + [states[-1]],
values = sess.run(outputs + [state],
feed_dict={inputs[0]: input_value})
full_dropout_values = sess.run(dropped_outputs,
feed_dict={inputs[0]: input_value})
@ -128,31 +128,29 @@ class RNNTest(tf.test.TestCase):
inputs = max_length * [
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
with tf.variable_scope("drop_scope"):
dynamic_outputs, dynamic_states = tf.nn.rnn(
dynamic_outputs, dynamic_state = tf.nn.rnn(
cell, inputs, sequence_length=sequence_length, dtype=tf.float32)
self.assertEqual(len(dynamic_outputs), len(inputs))
self.assertEqual(len(dynamic_states), len(inputs))
with self.test_session(use_gpu=False) as sess:
input_value = np.random.randn(batch_size, input_size)
dynamic_values = sess.run(dynamic_outputs,
feed_dict={inputs[0]: input_value,
sequence_length: [2, 3]})
dynamic_state_values = sess.run(dynamic_states,
dynamic_state_values = sess.run([dynamic_state],
feed_dict={inputs[0]: input_value,
sequence_length: [2, 3]})
# fully calculated for t = 0, 1, 2
for v in dynamic_values[:3]:
self.assertAllClose(v, input_value + 1.0)
for vi, v in enumerate(dynamic_state_values[:3]):
self.assertAllEqual(v, 1.0 * (vi + 1) *
np.ones((batch_size, input_size)))
# zeros for t = 3+
for v in dynamic_values[3:]:
self.assertAllEqual(v, np.zeros_like(input_value))
for v in dynamic_state_values[3:]:
self.assertAllEqual(v, np.zeros_like(input_value))
# final state is frozen from the state at t == 2, the last step computed
# (max(sequence_length) == 3)
self.assertAllEqual(
dynamic_state_values[0],
1.0 * (2 + 1) * np.ones((batch_size, input_size)))
class LSTMTest(tf.test.TestCase):
@ -219,7 +217,7 @@ class LSTMTest(tf.test.TestCase):
inputs = max_length * [
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
with tf.variable_scope("share_scope"):
outputs, states = tf.nn.state_saving_rnn(
outputs, state = tf.nn.state_saving_rnn(
cell, inputs, state_saver=state_saver, state_name="save_lstm")
self.assertEqual(len(outputs), len(inputs))
for out in outputs:
@ -228,7 +226,7 @@ class LSTMTest(tf.test.TestCase):
tf.initialize_all_variables().run()
input_value = np.random.randn(batch_size, input_size)
(last_state_value, saved_state_value) = sess.run(
[states[-1], state_saver.saved_state],
[state, state_saver.saved_state],
feed_dict={inputs[0]: input_value})
self.assertAllEqual(last_state_value, saved_state_value)
@ -340,10 +338,10 @@ class LSTMTest(tf.test.TestCase):
initializer=initializer, num_proj=num_proj)
with tf.variable_scope("noshard_scope"):
outputs_noshard, states_noshard = tf.nn.rnn(
outputs_noshard, state_noshard = tf.nn.rnn(
cell_noshard, inputs, dtype=tf.float32)
with tf.variable_scope("shard_scope"):
outputs_shard, states_shard = tf.nn.rnn(
outputs_shard, state_shard = tf.nn.rnn(
cell_shard, inputs, dtype=tf.float32)
self.assertEqual(len(outputs_noshard), len(inputs))
@ -354,8 +352,8 @@ class LSTMTest(tf.test.TestCase):
feeds = dict((x, input_value) for x in inputs)
values_noshard = sess.run(outputs_noshard, feed_dict=feeds)
values_shard = sess.run(outputs_shard, feed_dict=feeds)
state_values_noshard = sess.run(states_noshard, feed_dict=feeds)
state_values_shard = sess.run(states_shard, feed_dict=feeds)
state_values_noshard = sess.run([state_noshard], feed_dict=feeds)
state_values_shard = sess.run([state_shard], feed_dict=feeds)
self.assertEqual(len(values_noshard), len(values_shard))
self.assertEqual(len(state_values_noshard), len(state_values_shard))
for (v_noshard, v_shard) in zip(values_noshard, values_shard):
@ -389,22 +387,21 @@ class LSTMTest(tf.test.TestCase):
initializer=initializer)
dropout_cell = tf.nn.rnn_cell.DropoutWrapper(cell, 0.5, seed=0)
outputs, states = tf.nn.rnn(
outputs, state = tf.nn.rnn(
dropout_cell, inputs, sequence_length=sequence_length,
initial_state=cell.zero_state(batch_size, tf.float64))
self.assertEqual(len(outputs), len(inputs))
self.assertEqual(len(outputs), len(states))
tf.initialize_all_variables().run(feed_dict={sequence_length: [2, 3]})
input_value = np.asarray(np.random.randn(batch_size, input_size),
dtype=np.float64)
values = sess.run(outputs, feed_dict={inputs[0]: input_value,
sequence_length: [2, 3]})
state_values = sess.run(states, feed_dict={inputs[0]: input_value,
state_value = sess.run([state], feed_dict={inputs[0]: input_value,
sequence_length: [2, 3]})
self.assertEqual(values[0].dtype, input_value.dtype)
self.assertEqual(state_values[0].dtype, input_value.dtype)
self.assertEqual(state_value[0].dtype, input_value.dtype)
def testSharingWeightsWithReuse(self):
num_units = 3

View File

@ -35,19 +35,18 @@ class Seq2SeqTest(tf.test.TestCase):
with self.test_session() as sess:
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
_, enc_states = tf.nn.rnn(
_, enc_state = tf.nn.rnn(
tf.nn.rnn_cell.GRUCell(2), inp, dtype=tf.float32)
dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)]
cell = tf.nn.rnn_cell.OutputProjectionWrapper(
tf.nn.rnn_cell.GRUCell(2), 4)
dec, mem = tf.nn.seq2seq.rnn_decoder(dec_inp, enc_states[-1], cell)
dec, mem = tf.nn.seq2seq.rnn_decoder(dec_inp, enc_state, cell)
sess.run([tf.initialize_all_variables()])
res = sess.run(dec)
self.assertEqual(len(res), 3)
self.assertEqual(res[0].shape, (2, 4))
res = sess.run(mem)
self.assertEqual(len(res), 4)
res = sess.run([mem])
self.assertEqual(res[0].shape, (2, 2))
def testBasicRNNSeq2Seq(self):
@ -63,8 +62,7 @@ class Seq2SeqTest(tf.test.TestCase):
self.assertEqual(len(res), 3)
self.assertEqual(res[0].shape, (2, 4))
res = sess.run(mem)
self.assertEqual(len(res), 4)
res = sess.run([mem])
self.assertEqual(res[0].shape, (2, 2))
def testTiedRNNSeq2Seq(self):
@ -80,8 +78,8 @@ class Seq2SeqTest(tf.test.TestCase):
self.assertEqual(len(res), 3)
self.assertEqual(res[0].shape, (2, 4))
res = sess.run(mem)
self.assertEqual(len(res), 4)
res = sess.run([mem])
self.assertEqual(len(res), 1)
self.assertEqual(res[0].shape, (2, 2))
def testEmbeddingRNNDecoder(self):
@ -89,17 +87,17 @@ class Seq2SeqTest(tf.test.TestCase):
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
cell = tf.nn.rnn_cell.BasicLSTMCell(2)
_, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32)
_, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)]
dec, mem = tf.nn.seq2seq.embedding_rnn_decoder(dec_inp, enc_states[-1],
dec, mem = tf.nn.seq2seq.embedding_rnn_decoder(dec_inp, enc_state,
cell, 4)
sess.run([tf.initialize_all_variables()])
res = sess.run(dec)
self.assertEqual(len(res), 3)
self.assertEqual(res[0].shape, (2, 2))
res = sess.run(mem)
self.assertEqual(len(res), 4)
res = sess.run([mem])
self.assertEqual(len(res), 1)
self.assertEqual(res[0].shape, (2, 4))
def testEmbeddingRNNSeq2Seq(self):
@ -115,8 +113,7 @@ class Seq2SeqTest(tf.test.TestCase):
self.assertEqual(len(res), 3)
self.assertEqual(res[0].shape, (2, 5))
res = sess.run(mem)
self.assertEqual(len(res), 4)
res = sess.run([mem])
self.assertEqual(res[0].shape, (2, 4))
# Test externally provided output projection.
@ -161,8 +158,7 @@ class Seq2SeqTest(tf.test.TestCase):
self.assertEqual(len(res), 3)
self.assertEqual(res[0].shape, (2, 5))
res = sess.run(mem)
self.assertEqual(len(res), 4)
res = sess.run([mem])
self.assertEqual(res[0].shape, (2, 4))
# Test externally provided output projection.
@ -198,20 +194,19 @@ class Seq2SeqTest(tf.test.TestCase):
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
cell = tf.nn.rnn_cell.GRUCell(2)
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
enc_outputs, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32)
enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size])
for e in enc_outputs])
dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)]
dec, mem = tf.nn.seq2seq.attention_decoder(
dec_inp, enc_states[-1],
dec_inp, enc_state,
attn_states, cell, output_size=4)
sess.run([tf.initialize_all_variables()])
res = sess.run(dec)
self.assertEqual(len(res), 3)
self.assertEqual(res[0].shape, (2, 4))
res = sess.run(mem)
self.assertEqual(len(res), 4)
res = sess.run([mem])
self.assertEqual(res[0].shape, (2, 2))
def testAttentionDecoder2(self):
@ -219,12 +214,12 @@ class Seq2SeqTest(tf.test.TestCase):
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
cell = tf.nn.rnn_cell.GRUCell(2)
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
enc_outputs, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32)
enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size])
for e in enc_outputs])
dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)]
dec, mem = tf.nn.seq2seq.attention_decoder(
dec_inp, enc_states[-1],
dec_inp, enc_state,
attn_states, cell, output_size=4,
num_heads=2)
sess.run([tf.initialize_all_variables()])
@ -232,8 +227,7 @@ class Seq2SeqTest(tf.test.TestCase):
self.assertEqual(len(res), 3)
self.assertEqual(res[0].shape, (2, 4))
res = sess.run(mem)
self.assertEqual(len(res), 4)
res = sess.run([mem])
self.assertEqual(res[0].shape, (2, 2))
def testEmbeddingAttentionDecoder(self):
@ -241,12 +235,12 @@ class Seq2SeqTest(tf.test.TestCase):
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
cell = tf.nn.rnn_cell.GRUCell(2)
enc_outputs, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32)
enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size])
for e in enc_outputs])
dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)]
dec, mem = tf.nn.seq2seq.embedding_attention_decoder(
dec_inp, enc_states[-1],
dec_inp, enc_state,
attn_states, cell, 4,
output_size=3)
sess.run([tf.initialize_all_variables()])
@ -254,8 +248,7 @@ class Seq2SeqTest(tf.test.TestCase):
self.assertEqual(len(res), 3)
self.assertEqual(res[0].shape, (2, 3))
res = sess.run(mem)
self.assertEqual(len(res), 4)
res = sess.run([mem])
self.assertEqual(res[0].shape, (2, 2))
def testEmbeddingAttentionSeq2Seq(self):
@ -271,8 +264,7 @@ class Seq2SeqTest(tf.test.TestCase):
self.assertEqual(len(res), 3)
self.assertEqual(res[0].shape, (2, 5))
res = sess.run(mem)
self.assertEqual(len(res), 4)
res = sess.run([mem])
self.assertEqual(res[0].shape, (2, 4))
# Test externally provided output projection.

View File

@ -34,12 +34,10 @@ def rnn(cell, inputs, initial_state=None, dtype=None,
The simplest form of RNN network generated is:
state = cell.zero_state(...)
outputs = []
states = []
for input_ in inputs:
output, state = cell(input_, state)
outputs.append(output)
states.append(state)
return (outputs, states)
return (outputs, state)
However, a few other options are available:
@ -65,9 +63,9 @@ def rnn(cell, inputs, initial_state=None, dtype=None,
scope: VariableScope for the created subgraph; defaults to "RNN".
Returns:
A pair (outputs, states) where:
A pair (outputs, state) where:
outputs is a length T list of outputs (one for each input)
states is a length T list of states (one state following each input)
state is the final state
Raises:
TypeError: If "cell" is not an instance of RNNCell.
@ -82,7 +80,6 @@ def rnn(cell, inputs, initial_state=None, dtype=None,
raise ValueError("inputs must not be empty")
outputs = []
states = []
with vs.variable_scope(scope or "RNN"):
batch_size = array_ops.shape(inputs[0])[0]
if initial_state is not None:
@ -93,30 +90,25 @@ def rnn(cell, inputs, initial_state=None, dtype=None,
state = cell.zero_state(batch_size, dtype)
if sequence_length: # Prepare variables
zero_output_state = (
array_ops.zeros(array_ops.pack([batch_size, cell.output_size]),
inputs[0].dtype),
array_ops.zeros(array_ops.pack([batch_size, cell.state_size]),
state.dtype))
zero_output = array_ops.zeros(
array_ops.pack([batch_size, cell.output_size]), inputs[0].dtype)
max_sequence_length = math_ops.reduce_max(sequence_length)
for time, input_ in enumerate(inputs):
if time > 0: vs.get_variable_scope().reuse_variables()
# pylint: disable=cell-var-from-loop
def output_state():
return cell(input_, state)
output_state = lambda: cell(input_, state)
# pylint: enable=cell-var-from-loop
if sequence_length:
(output, state) = control_flow_ops.cond(
time >= max_sequence_length,
lambda: zero_output_state, output_state)
lambda: (zero_output, state), output_state)
else:
(output, state) = output_state()
outputs.append(output)
states.append(state)
return (outputs, states)
return (outputs, state)
def state_saving_rnn(cell, inputs, state_saver, state_name,
@ -134,22 +126,22 @@ def state_saving_rnn(cell, inputs, state_saver, state_name,
scope: VariableScope for the created subgraph; defaults to "RNN".
Returns:
A pair (outputs, states) where:
A pair (outputs, state) where:
outputs is a length T list of outputs (one for each input)
states is a length T list of states (one state following each input)
state is the final state
Raises:
TypeError: If "cell" is not an instance of RNNCell.
ValueError: If inputs is None or an empty list.
"""
initial_state = state_saver.state(state_name)
(outputs, states) = rnn(cell, inputs, initial_state=initial_state,
sequence_length=sequence_length, scope=scope)
save_state = state_saver.save_state(state_name, states[-1])
(outputs, state) = rnn(cell, inputs, initial_state=initial_state,
sequence_length=sequence_length, scope=scope)
save_state = state_saver.save_state(state_name, state)
with ops.control_dependencies([save_state]):
outputs[-1] = array_ops.identity(outputs[-1])
return (outputs, states)
return (outputs, state)
def _reverse_seq(input_seq, lengths):

View File

@ -54,14 +54,13 @@ def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None,
Returns:
outputs: A list of the same length as decoder_inputs of 2D Tensors with
shape [batch_size x cell.output_size] containing generated outputs.
states: The state of each cell in each time-step. This is a list with
length len(decoder_inputs) -- one item for each time-step.
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
state: The state of each cell at the final time-step.
It is a 2D Tensor of shape [batch_size x cell.state_size].
(Note that in some cases, like basic RNN cell or GRU cell, outputs and
states can be the same. They are different for LSTM cells though.)
"""
with vs.variable_scope(scope or "rnn_decoder"):
states = [initial_state]
state = initial_state
outputs = []
prev = None
for i in xrange(len(decoder_inputs)):
@ -72,12 +71,11 @@ def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None,
inp = array_ops.stop_gradient(loop_function(prev, i))
if i > 0:
vs.get_variable_scope().reuse_variables()
output, new_state = cell(inp, states[-1])
output, state = cell(inp, state)
outputs.append(output)
states.append(new_state)
if loop_function is not None:
prev = array_ops.stop_gradient(output)
return outputs, states
return outputs, state
def basic_rnn_seq2seq(
@ -98,13 +96,12 @@ def basic_rnn_seq2seq(
Returns:
outputs: A list of the same length as decoder_inputs of 2D Tensors with
shape [batch_size x cell.output_size] containing the generated outputs.
states: The state of each decoder cell in each time-step. This is a list
with length len(decoder_inputs) -- one item for each time-step.
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
state: The state of each decoder cell in the final time-step.
It is a 2D Tensor of shape [batch_size x cell.state_size].
"""
with vs.variable_scope(scope or "basic_rnn_seq2seq"):
_, enc_states = rnn.rnn(cell, encoder_inputs, dtype=dtype)
return rnn_decoder(decoder_inputs, enc_states[-1], cell)
_, enc_state = rnn.rnn(cell, encoder_inputs, dtype=dtype)
return rnn_decoder(decoder_inputs, enc_state, cell)
def tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
@ -128,16 +125,16 @@ def tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
Returns:
outputs: A list of the same length as decoder_inputs of 2D Tensors with
shape [batch_size x cell.output_size] containing the generated outputs.
states: The state of each decoder cell in each time-step. This is a list
state: The state of each decoder cell in each time-step. This is a list
with length len(decoder_inputs) -- one item for each time-step.
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
It is a 2D Tensor of shape [batch_size x cell.state_size].
"""
with vs.variable_scope("combined_tied_rnn_seq2seq"):
scope = scope or "tied_rnn_seq2seq"
_, enc_states = rnn.rnn(
_, enc_state = rnn.rnn(
cell, encoder_inputs, dtype=dtype, scope=scope)
vs.get_variable_scope().reuse_variables()
return rnn_decoder(decoder_inputs, enc_states[-1], cell,
return rnn_decoder(decoder_inputs, enc_state, cell,
loop_function=loop_function, scope=scope)
@ -167,9 +164,9 @@ def embedding_rnn_decoder(decoder_inputs, initial_state, cell, num_symbols,
Returns:
outputs: A list of the same length as decoder_inputs of 2D Tensors with
shape [batch_size x cell.output_size] containing the generated outputs.
states: The state of each decoder cell in each time-step. This is a list
state: The state of each decoder cell in each time-step. This is a list
with length len(decoder_inputs) -- one item for each time-step.
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
It is a 2D Tensor of shape [batch_size x cell.state_size].
Raises:
ValueError: when output_projection has the wrong shape.
@ -240,37 +237,37 @@ def embedding_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
Returns:
outputs: A list of the same length as decoder_inputs of 2D Tensors with
shape [batch_size x num_decoder_symbols] containing the generated outputs.
states: The state of each decoder cell in each time-step. This is a list
state: The state of each decoder cell in each time-step. This is a list
with length len(decoder_inputs) -- one item for each time-step.
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
It is a 2D Tensor of shape [batch_size x cell.state_size].
"""
with vs.variable_scope(scope or "embedding_rnn_seq2seq"):
# Encoder.
encoder_cell = rnn_cell.EmbeddingWrapper(cell, num_encoder_symbols)
_, encoder_states = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype)
_, encoder_state = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype)
# Decoder.
if output_projection is None:
cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)
if isinstance(feed_previous, bool):
return embedding_rnn_decoder(decoder_inputs, encoder_states[-1], cell,
return embedding_rnn_decoder(decoder_inputs, encoder_state, cell,
num_decoder_symbols, output_projection,
feed_previous)
else: # If feed_previous is a Tensor, we construct 2 graphs and use cond.
outputs1, states1 = embedding_rnn_decoder(
decoder_inputs, encoder_states[-1], cell, num_decoder_symbols,
outputs1, state1 = embedding_rnn_decoder(
decoder_inputs, encoder_state, cell, num_decoder_symbols,
output_projection, True)
vs.get_variable_scope().reuse_variables()
outputs2, states2 = embedding_rnn_decoder(
decoder_inputs, encoder_states[-1], cell, num_decoder_symbols,
outputs2, state2 = embedding_rnn_decoder(
decoder_inputs, encoder_state, cell, num_decoder_symbols,
output_projection, False)
outputs = control_flow_ops.cond(feed_previous,
lambda: outputs1, lambda: outputs2)
states = control_flow_ops.cond(feed_previous,
lambda: states1, lambda: states2)
return outputs, states
state = control_flow_ops.cond(feed_previous,
lambda: state1, lambda: state2)
return outputs, state
def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
@ -305,9 +302,8 @@ def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
Returns:
outputs: A list of the same length as decoder_inputs of 2D Tensors with
shape [batch_size x num_decoder_symbols] containing the generated outputs.
states: The state of each decoder cell in each time-step. This is a list
with length len(decoder_inputs) -- one item for each time-step.
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
state: The state of each decoder cell at the final time-step.
It is a 2D Tensor of shape [batch_size x cell.state_size].
Raises:
ValueError: when output_projection has the wrong shape.
@ -344,18 +340,18 @@ def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
return tied_rnn_seq2seq(emb_encoder_inputs, emb_decoder_inputs, cell,
loop_function=loop_function, dtype=dtype)
else: # If feed_previous is a Tensor, we construct 2 graphs and use cond.
outputs1, states1 = tied_rnn_seq2seq(
outputs1, state1 = tied_rnn_seq2seq(
emb_encoder_inputs, emb_decoder_inputs, cell,
loop_function=extract_argmax_and_embed, dtype=dtype)
vs.get_variable_scope().reuse_variables()
outputs2, states2 = tied_rnn_seq2seq(
outputs2, state2 = tied_rnn_seq2seq(
emb_encoder_inputs, emb_decoder_inputs, cell, dtype=dtype)
outputs = control_flow_ops.cond(feed_previous,
lambda: outputs1, lambda: outputs2)
states = control_flow_ops.cond(feed_previous,
lambda: states1, lambda: states2)
return outputs, states
state = control_flow_ops.cond(feed_previous,
lambda: state1, lambda: state2)
return outputs, state
def attention_decoder(decoder_inputs, initial_state, attention_states, cell,
@ -397,9 +393,8 @@ def attention_decoder(decoder_inputs, initial_state, attention_states, cell,
new_attn = softmax(V^T * tanh(W * attention_states + U * new_state))
and then we calculate the output:
output = linear(cell_output, new_attn).
states: The state of each decoder cell in each time-step. This is a list
with length len(decoder_inputs) -- one item for each time-step.
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
state: The state of each decoder cell at the final time-step.
It is a 2D Tensor of shape [batch_size x cell.state_size].
Raises:
ValueError: when num_heads is not positive, there are no inputs, or shapes
@ -431,7 +426,7 @@ def attention_decoder(decoder_inputs, initial_state, attention_states, cell,
hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
v.append(vs.get_variable("AttnV_%d" % a, [attention_vec_size]))
states = [initial_state]
state = initial_state
def attention(query):
"""Put attention masks on hidden using hidden_features and query."""
@ -471,14 +466,13 @@ def attention_decoder(decoder_inputs, initial_state, attention_states, cell,
# Merge input and previous attentions into one vector of the right size.
x = rnn_cell.linear([inp] + attns, cell.input_size, True)
# Run the RNN.
cell_output, new_state = cell(x, states[-1])
states.append(new_state)
cell_output, state = cell(x, state)
# Run the attention mechanism.
if i == 0 and initial_state_attention:
with vs.variable_scope(vs.get_variable_scope(), reuse=True):
attns = attention(new_state)
attns = attention(state)
else:
attns = attention(new_state)
attns = attention(state)
with vs.variable_scope("AttnOutputProjection"):
output = rnn_cell.linear([cell_output] + attns, output_size, True)
@ -487,7 +481,7 @@ def attention_decoder(decoder_inputs, initial_state, attention_states, cell,
prev = array_ops.stop_gradient(output)
outputs.append(output)
return outputs, states
return outputs, state
def embedding_attention_decoder(decoder_inputs, initial_state, attention_states,
@ -526,9 +520,8 @@ def embedding_attention_decoder(decoder_inputs, initial_state, attention_states,
Returns:
outputs: A list of the same length as decoder_inputs of 2D Tensors with
shape [batch_size x output_size] containing the generated outputs.
states: The state of each decoder cell in each time-step. This is a list
with length len(decoder_inputs) -- one item for each time-step.
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
state: The state of each decoder cell at the final time-step.
It is a 2D Tensor of shape [batch_size x cell.state_size].
Raises:
ValueError: when output_projection has the wrong shape.
@ -607,14 +600,13 @@ def embedding_attention_seq2seq(encoder_inputs, decoder_inputs, cell,
Returns:
outputs: A list of the same length as decoder_inputs of 2D Tensors with
shape [batch_size x num_decoder_symbols] containing the generated outputs.
states: The state of each decoder cell in each time-step. This is a list
with length len(decoder_inputs) -- one item for each time-step.
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
state: The state of each decoder cell at the final time-step.
It is a 2D Tensor of shape [batch_size x cell.state_size].
"""
with vs.variable_scope(scope or "embedding_attention_seq2seq"):
# Encoder.
encoder_cell = rnn_cell.EmbeddingWrapper(cell, num_encoder_symbols)
encoder_outputs, encoder_states = rnn.rnn(
encoder_outputs, encoder_state = rnn.rnn(
encoder_cell, encoder_inputs, dtype=dtype)
# First calculate a concatenation of encoder outputs to put attention on.
@ -630,25 +622,25 @@ def embedding_attention_seq2seq(encoder_inputs, decoder_inputs, cell,
if isinstance(feed_previous, bool):
return embedding_attention_decoder(
decoder_inputs, encoder_states[-1], attention_states, cell,
decoder_inputs, encoder_state, attention_states, cell,
num_decoder_symbols, num_heads, output_size, output_projection,
feed_previous, initial_state_attention=initial_state_attention)
else: # If feed_previous is a Tensor, we construct 2 graphs and use cond.
outputs1, states1 = embedding_attention_decoder(
decoder_inputs, encoder_states[-1], attention_states, cell,
outputs1, state1 = embedding_attention_decoder(
decoder_inputs, encoder_state, attention_states, cell,
num_decoder_symbols, num_heads, output_size, output_projection, True,
initial_state_attention=initial_state_attention)
vs.get_variable_scope().reuse_variables()
outputs2, states2 = embedding_attention_decoder(
decoder_inputs, encoder_states[-1], attention_states, cell,
outputs2, state2 = embedding_attention_decoder(
decoder_inputs, encoder_state, attention_states, cell,
num_decoder_symbols, num_heads, output_size, output_projection, False,
initial_state_attention=initial_state_attention)
outputs = control_flow_ops.cond(feed_previous,
lambda: outputs1, lambda: outputs2)
states = control_flow_ops.cond(feed_previous,
lambda: states1, lambda: states2)
return outputs, states
state = control_flow_ops.cond(feed_previous,
lambda: state1, lambda: state2)
return outputs, state
def sequence_loss_by_example(logits, targets, weights, num_decoder_symbols,