diff --git a/RELEASE.md b/RELEASE.md index 41149b3af6d..fd46bcdff36 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -31,6 +31,9 @@ * ASSERT_OK / EXPECT_OK macros conflicted with external projects, so they were renamed TF_ASSERT_OK, TF_EXPECT_OK. The existing macros are currently maintained for short-term compatibility but will be removed. +* The non-public `nn.rnn` and the various `nn.seq2seq` methods now return + just the final state instead of the list of all states. + ## Bug fixes diff --git a/tensorflow/models/rnn/ptb/ptb_word_lm.py b/tensorflow/models/rnn/ptb/ptb_word_lm.py index 78890a4a3e0..47be980e11d 100644 --- a/tensorflow/models/rnn/ptb/ptb_word_lm.py +++ b/tensorflow/models/rnn/ptb/ptb_word_lm.py @@ -117,16 +117,14 @@ class PTBModel(object): # from tensorflow.models.rnn import rnn # inputs = [tf.squeeze(input_, [1]) # for input_ in tf.split(1, num_steps, inputs)] - # outputs, states = rnn.rnn(cell, inputs, initial_state=self._initial_state) + # outputs, state = rnn.rnn(cell, inputs, initial_state=self._initial_state) outputs = [] - states = [] state = self._initial_state with tf.variable_scope("RNN"): for time_step in range(num_steps): if time_step > 0: tf.get_variable_scope().reuse_variables() (cell_output, state) = cell(inputs[:, time_step, :], state) outputs.append(cell_output) - states.append(state) output = tf.reshape(tf.concat(1, outputs), [-1, size]) softmax_w = tf.get_variable("softmax_w", [size, vocab_size]) @@ -137,7 +135,7 @@ class PTBModel(object): [tf.ones([batch_size * num_steps])], vocab_size) self._cost = cost = tf.reduce_sum(loss) / batch_size - self._final_state = states[-1] + self._final_state = state if not is_training: return diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index a8724fcd34f..c9344d63ff3 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -68,7 +68,7 @@ class RNNTest(tf.test.TestCase): max_length = 8 # unrolled up 
to this length inputs = max_length * [ tf.placeholder(tf.float32, shape=(batch_size, input_size))] - outputs, states = tf.nn.rnn(cell, inputs, dtype=tf.float32) + outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32) self.assertEqual(len(outputs), len(inputs)) for out, inp in zip(outputs, inputs): self.assertEqual(out.get_shape(), inp.get_shape()) @@ -76,7 +76,7 @@ class RNNTest(tf.test.TestCase): with self.test_session(use_gpu=False) as sess: input_value = np.random.randn(batch_size, input_size) - values = sess.run(outputs + [states[-1]], + values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value}) # Outputs @@ -98,7 +98,7 @@ class RNNTest(tf.test.TestCase): inputs = max_length * [ tf.placeholder(tf.float32, shape=(batch_size, input_size))] with tf.variable_scope("share_scope"): - outputs, states = tf.nn.rnn(cell, inputs, dtype=tf.float32) + outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32) with tf.variable_scope("drop_scope"): dropped_outputs, _ = tf.nn.rnn( full_dropout_cell, inputs, dtype=tf.float32) @@ -109,7 +109,7 @@ class RNNTest(tf.test.TestCase): with self.test_session(use_gpu=False) as sess: input_value = np.random.randn(batch_size, input_size) - values = sess.run(outputs + [states[-1]], + values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value}) full_dropout_values = sess.run(dropped_outputs, feed_dict={inputs[0]: input_value}) @@ -128,31 +128,29 @@ class RNNTest(tf.test.TestCase): inputs = max_length * [ tf.placeholder(tf.float32, shape=(batch_size, input_size))] with tf.variable_scope("drop_scope"): - dynamic_outputs, dynamic_states = tf.nn.rnn( + dynamic_outputs, dynamic_state = tf.nn.rnn( cell, inputs, sequence_length=sequence_length, dtype=tf.float32) self.assertEqual(len(dynamic_outputs), len(inputs)) - self.assertEqual(len(dynamic_states), len(inputs)) with self.test_session(use_gpu=False) as sess: input_value = np.random.randn(batch_size, input_size) dynamic_values = sess.run(dynamic_outputs, 
feed_dict={inputs[0]: input_value, sequence_length: [2, 3]}) - dynamic_state_values = sess.run(dynamic_states, + dynamic_state_values = sess.run([dynamic_state], feed_dict={inputs[0]: input_value, sequence_length: [2, 3]}) # fully calculated for t = 0, 1, 2 for v in dynamic_values[:3]: self.assertAllClose(v, input_value + 1.0) - for vi, v in enumerate(dynamic_state_values[:3]): - self.assertAllEqual(v, 1.0 * (vi + 1) * - np.ones((batch_size, input_size))) # zeros for t = 3+ for v in dynamic_values[3:]: self.assertAllEqual(v, np.zeros_like(input_value)) - for v in dynamic_state_values[3:]: - self.assertAllEqual(v, np.zeros_like(input_value)) + # final state is frozen after step t == max(sequence_length) - 1 == 2 + self.assertAllEqual( + dynamic_state_values[0], + 1.0 * (2 + 1) * np.ones((batch_size, input_size))) class LSTMTest(tf.test.TestCase): @@ -219,7 +217,7 @@ class LSTMTest(tf.test.TestCase): inputs = max_length * [ tf.placeholder(tf.float32, shape=(batch_size, input_size))] with tf.variable_scope("share_scope"): - outputs, states = tf.nn.state_saving_rnn( + outputs, state = tf.nn.state_saving_rnn( cell, inputs, state_saver=state_saver, state_name="save_lstm") self.assertEqual(len(outputs), len(inputs)) for out in outputs: @@ -228,7 +226,7 @@ class LSTMTest(tf.test.TestCase): tf.initialize_all_variables().run() input_value = np.random.randn(batch_size, input_size) (last_state_value, saved_state_value) = sess.run( - [states[-1], state_saver.saved_state], + [state, state_saver.saved_state], feed_dict={inputs[0]: input_value}) self.assertAllEqual(last_state_value, saved_state_value) @@ -340,10 +338,10 @@ class LSTMTest(tf.test.TestCase): initializer=initializer, num_proj=num_proj) with tf.variable_scope("noshard_scope"): - outputs_noshard, states_noshard = tf.nn.rnn( + outputs_noshard, state_noshard = tf.nn.rnn( cell_noshard, inputs, dtype=tf.float32) with tf.variable_scope("shard_scope"): - outputs_shard, states_shard = tf.nn.rnn( + outputs_shard, state_shard = 
tf.nn.rnn( cell_shard, inputs, dtype=tf.float32) self.assertEqual(len(outputs_noshard), len(inputs)) @@ -354,8 +352,8 @@ class LSTMTest(tf.test.TestCase): feeds = dict((x, input_value) for x in inputs) values_noshard = sess.run(outputs_noshard, feed_dict=feeds) values_shard = sess.run(outputs_shard, feed_dict=feeds) - state_values_noshard = sess.run(states_noshard, feed_dict=feeds) - state_values_shard = sess.run(states_shard, feed_dict=feeds) + state_values_noshard = sess.run([state_noshard], feed_dict=feeds) + state_values_shard = sess.run([state_shard], feed_dict=feeds) self.assertEqual(len(values_noshard), len(values_shard)) self.assertEqual(len(state_values_noshard), len(state_values_shard)) for (v_noshard, v_shard) in zip(values_noshard, values_shard): @@ -389,22 +387,21 @@ class LSTMTest(tf.test.TestCase): initializer=initializer) dropout_cell = tf.nn.rnn_cell.DropoutWrapper(cell, 0.5, seed=0) - outputs, states = tf.nn.rnn( + outputs, state = tf.nn.rnn( dropout_cell, inputs, sequence_length=sequence_length, initial_state=cell.zero_state(batch_size, tf.float64)) self.assertEqual(len(outputs), len(inputs)) - self.assertEqual(len(outputs), len(states)) tf.initialize_all_variables().run(feed_dict={sequence_length: [2, 3]}) input_value = np.asarray(np.random.randn(batch_size, input_size), dtype=np.float64) values = sess.run(outputs, feed_dict={inputs[0]: input_value, sequence_length: [2, 3]}) - state_values = sess.run(states, feed_dict={inputs[0]: input_value, + state_value = sess.run([state], feed_dict={inputs[0]: input_value, sequence_length: [2, 3]}) self.assertEqual(values[0].dtype, input_value.dtype) - self.assertEqual(state_values[0].dtype, input_value.dtype) + self.assertEqual(state_value[0].dtype, input_value.dtype) def testSharingWeightsWithReuse(self): num_units = 3 diff --git a/tensorflow/python/kernel_tests/seq2seq_test.py b/tensorflow/python/kernel_tests/seq2seq_test.py index b6c78234576..ff4bc0bd454 100644 --- 
a/tensorflow/python/kernel_tests/seq2seq_test.py +++ b/tensorflow/python/kernel_tests/seq2seq_test.py @@ -35,19 +35,18 @@ class Seq2SeqTest(tf.test.TestCase): with self.test_session() as sess: with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)] - _, enc_states = tf.nn.rnn( + _, enc_state = tf.nn.rnn( tf.nn.rnn_cell.GRUCell(2), inp, dtype=tf.float32) dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)] cell = tf.nn.rnn_cell.OutputProjectionWrapper( tf.nn.rnn_cell.GRUCell(2), 4) - dec, mem = tf.nn.seq2seq.rnn_decoder(dec_inp, enc_states[-1], cell) + dec, mem = tf.nn.seq2seq.rnn_decoder(dec_inp, enc_state, cell) sess.run([tf.initialize_all_variables()]) res = sess.run(dec) self.assertEqual(len(res), 3) self.assertEqual(res[0].shape, (2, 4)) - res = sess.run(mem) - self.assertEqual(len(res), 4) + res = sess.run([mem]) self.assertEqual(res[0].shape, (2, 2)) def testBasicRNNSeq2Seq(self): @@ -63,8 +62,7 @@ class Seq2SeqTest(tf.test.TestCase): self.assertEqual(len(res), 3) self.assertEqual(res[0].shape, (2, 4)) - res = sess.run(mem) - self.assertEqual(len(res), 4) + res = sess.run([mem]) self.assertEqual(res[0].shape, (2, 2)) def testTiedRNNSeq2Seq(self): @@ -80,8 +78,8 @@ class Seq2SeqTest(tf.test.TestCase): self.assertEqual(len(res), 3) self.assertEqual(res[0].shape, (2, 4)) - res = sess.run(mem) - self.assertEqual(len(res), 4) + res = sess.run([mem]) + self.assertEqual(len(res), 1) self.assertEqual(res[0].shape, (2, 2)) def testEmbeddingRNNDecoder(self): @@ -89,17 +87,17 @@ class Seq2SeqTest(tf.test.TestCase): with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)] cell = tf.nn.rnn_cell.BasicLSTMCell(2) - _, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32) + _, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32) dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)] - dec, mem = 
tf.nn.seq2seq.embedding_rnn_decoder(dec_inp, enc_states[-1], + dec, mem = tf.nn.seq2seq.embedding_rnn_decoder(dec_inp, enc_state, cell, 4) sess.run([tf.initialize_all_variables()]) res = sess.run(dec) self.assertEqual(len(res), 3) self.assertEqual(res[0].shape, (2, 2)) - res = sess.run(mem) - self.assertEqual(len(res), 4) + res = sess.run([mem]) + self.assertEqual(len(res), 1) self.assertEqual(res[0].shape, (2, 4)) def testEmbeddingRNNSeq2Seq(self): @@ -115,8 +113,7 @@ class Seq2SeqTest(tf.test.TestCase): self.assertEqual(len(res), 3) self.assertEqual(res[0].shape, (2, 5)) - res = sess.run(mem) - self.assertEqual(len(res), 4) + res = sess.run([mem]) self.assertEqual(res[0].shape, (2, 4)) # Test externally provided output projection. @@ -161,8 +158,7 @@ class Seq2SeqTest(tf.test.TestCase): self.assertEqual(len(res), 3) self.assertEqual(res[0].shape, (2, 5)) - res = sess.run(mem) - self.assertEqual(len(res), 4) + res = sess.run([mem]) self.assertEqual(res[0].shape, (2, 4)) # Test externally provided output projection. 
@@ -198,20 +194,19 @@ class Seq2SeqTest(tf.test.TestCase): with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): cell = tf.nn.rnn_cell.GRUCell(2) inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)] - enc_outputs, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32) + enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32) attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs]) dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)] dec, mem = tf.nn.seq2seq.attention_decoder( - dec_inp, enc_states[-1], + dec_inp, enc_state, attn_states, cell, output_size=4) sess.run([tf.initialize_all_variables()]) res = sess.run(dec) self.assertEqual(len(res), 3) self.assertEqual(res[0].shape, (2, 4)) - res = sess.run(mem) - self.assertEqual(len(res), 4) + res = sess.run([mem]) self.assertEqual(res[0].shape, (2, 2)) def testAttentionDecoder2(self): @@ -219,12 +214,12 @@ class Seq2SeqTest(tf.test.TestCase): with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): cell = tf.nn.rnn_cell.GRUCell(2) inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)] - enc_outputs, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32) + enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32) attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs]) dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)] dec, mem = tf.nn.seq2seq.attention_decoder( - dec_inp, enc_states[-1], + dec_inp, enc_state, attn_states, cell, output_size=4, num_heads=2) sess.run([tf.initialize_all_variables()]) @@ -232,8 +227,7 @@ class Seq2SeqTest(tf.test.TestCase): self.assertEqual(len(res), 3) self.assertEqual(res[0].shape, (2, 4)) - res = sess.run(mem) - self.assertEqual(len(res), 4) + res = sess.run([mem]) self.assertEqual(res[0].shape, (2, 2)) def testEmbeddingAttentionDecoder(self): @@ -241,12 +235,12 @@ class Seq2SeqTest(tf.test.TestCase): with tf.variable_scope("root", 
initializer=tf.constant_initializer(0.5)): inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)] cell = tf.nn.rnn_cell.GRUCell(2) - enc_outputs, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32) + enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32) attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs]) dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)] dec, mem = tf.nn.seq2seq.embedding_attention_decoder( - dec_inp, enc_states[-1], + dec_inp, enc_state, attn_states, cell, 4, output_size=3) sess.run([tf.initialize_all_variables()]) @@ -254,8 +248,7 @@ class Seq2SeqTest(tf.test.TestCase): self.assertEqual(len(res), 3) self.assertEqual(res[0].shape, (2, 3)) - res = sess.run(mem) - self.assertEqual(len(res), 4) + res = sess.run([mem]) self.assertEqual(res[0].shape, (2, 2)) def testEmbeddingAttentionSeq2Seq(self): @@ -271,8 +264,7 @@ class Seq2SeqTest(tf.test.TestCase): self.assertEqual(len(res), 3) self.assertEqual(res[0].shape, (2, 5)) - res = sess.run(mem) - self.assertEqual(len(res), 4) + res = sess.run([mem]) self.assertEqual(res[0].shape, (2, 4)) # Test externally provided output projection. diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index 710b602aa5e..cd31e8b3e1f 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -34,12 +34,10 @@ def rnn(cell, inputs, initial_state=None, dtype=None, The simplest form of RNN network generated is: state = cell.zero_state(...) outputs = [] - states = [] for input_ in inputs: output, state = cell(input_, state) outputs.append(output) - states.append(state) - return (outputs, states) + return (outputs, state) However, a few other options are available: @@ -65,9 +63,9 @@ def rnn(cell, inputs, initial_state=None, dtype=None, scope: VariableScope for the created subgraph; defaults to "RNN". 
Returns: - A pair (outputs, states) where: + A pair (outputs, state) where: outputs is a length T list of outputs (one for each input) - states is a length T list of states (one state following each input) + state is the final state Raises: TypeError: If "cell" is not an instance of RNNCell. @@ -82,7 +80,6 @@ def rnn(cell, inputs, initial_state=None, dtype=None, raise ValueError("inputs must not be empty") outputs = [] - states = [] with vs.variable_scope(scope or "RNN"): batch_size = array_ops.shape(inputs[0])[0] if initial_state is not None: @@ -93,30 +90,25 @@ def rnn(cell, inputs, initial_state=None, dtype=None, state = cell.zero_state(batch_size, dtype) if sequence_length: # Prepare variables - zero_output_state = ( - array_ops.zeros(array_ops.pack([batch_size, cell.output_size]), - inputs[0].dtype), - array_ops.zeros(array_ops.pack([batch_size, cell.state_size]), - state.dtype)) + zero_output = array_ops.zeros( + array_ops.pack([batch_size, cell.output_size]), inputs[0].dtype) max_sequence_length = math_ops.reduce_max(sequence_length) for time, input_ in enumerate(inputs): if time > 0: vs.get_variable_scope().reuse_variables() # pylint: disable=cell-var-from-loop - def output_state(): - return cell(input_, state) + output_state = lambda: cell(input_, state) # pylint: enable=cell-var-from-loop if sequence_length: (output, state) = control_flow_ops.cond( time >= max_sequence_length, - lambda: zero_output_state, output_state) + lambda: (zero_output, state), output_state) else: (output, state) = output_state() outputs.append(output) - states.append(state) - return (outputs, states) + return (outputs, state) def state_saving_rnn(cell, inputs, state_saver, state_name, @@ -134,22 +126,22 @@ def state_saving_rnn(cell, inputs, state_saver, state_name, scope: VariableScope for the created subgraph; defaults to "RNN". 
Returns: - A pair (outputs, states) where: + A pair (outputs, state) where: outputs is a length T list of outputs (one for each input) - states is a length T list of states (one state following each input) + state is the final state Raises: TypeError: If "cell" is not an instance of RNNCell. ValueError: If inputs is None or an empty list. """ initial_state = state_saver.state(state_name) - (outputs, states) = rnn(cell, inputs, initial_state=initial_state, - sequence_length=sequence_length, scope=scope) - save_state = state_saver.save_state(state_name, states[-1]) + (outputs, state) = rnn(cell, inputs, initial_state=initial_state, + sequence_length=sequence_length, scope=scope) + save_state = state_saver.save_state(state_name, state) with ops.control_dependencies([save_state]): outputs[-1] = array_ops.identity(outputs[-1]) - return (outputs, states) + return (outputs, state) def _reverse_seq(input_seq, lengths): diff --git a/tensorflow/python/ops/seq2seq.py b/tensorflow/python/ops/seq2seq.py index 79316e4eefb..860a120ffd1 100644 --- a/tensorflow/python/ops/seq2seq.py +++ b/tensorflow/python/ops/seq2seq.py @@ -54,14 +54,13 @@ def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None, Returns: outputs: A list of the same length as decoder_inputs of 2D Tensors with shape [batch_size x cell.output_size] containing generated outputs. - states: The state of each cell in each time-step. This is a list with - length len(decoder_inputs) -- one item for each time-step. - Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + state: The state of each cell at the final time-step. + It is a 2D Tensor of shape [batch_size x cell.state_size]. (Note that in some cases, like basic RNN cell or GRU cell, outputs and states can be the same. They are different for LSTM cells though.) 
""" with vs.variable_scope(scope or "rnn_decoder"): - states = [initial_state] + state = initial_state outputs = [] prev = None for i in xrange(len(decoder_inputs)): @@ -72,12 +71,11 @@ def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None, inp = array_ops.stop_gradient(loop_function(prev, i)) if i > 0: vs.get_variable_scope().reuse_variables() - output, new_state = cell(inp, states[-1]) + output, state = cell(inp, state) outputs.append(output) - states.append(new_state) if loop_function is not None: prev = array_ops.stop_gradient(output) - return outputs, states + return outputs, state def basic_rnn_seq2seq( @@ -98,13 +96,12 @@ def basic_rnn_seq2seq( Returns: outputs: A list of the same length as decoder_inputs of 2D Tensors with shape [batch_size x cell.output_size] containing the generated outputs. - states: The state of each decoder cell in each time-step. This is a list - with length len(decoder_inputs) -- one item for each time-step. - Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + state: The state of each decoder cell in the final time-step. + It is a 2D Tensor of shape [batch_size x cell.state_size]. """ with vs.variable_scope(scope or "basic_rnn_seq2seq"): - _, enc_states = rnn.rnn(cell, encoder_inputs, dtype=dtype) - return rnn_decoder(decoder_inputs, enc_states[-1], cell) + _, enc_state = rnn.rnn(cell, encoder_inputs, dtype=dtype) + return rnn_decoder(decoder_inputs, enc_state, cell) def tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, @@ -128,16 +125,16 @@ def tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, Returns: outputs: A list of the same length as decoder_inputs of 2D Tensors with shape [batch_size x cell.output_size] containing the generated outputs. - states: The state of each decoder cell in each time-step. This is a list + state: The state of each decoder cell in each time-step. This is a list with length len(decoder_inputs) -- one item for each time-step. 
- Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + It is a 2D Tensor of shape [batch_size x cell.state_size]. """ with vs.variable_scope("combined_tied_rnn_seq2seq"): scope = scope or "tied_rnn_seq2seq" - _, enc_states = rnn.rnn( + _, enc_state = rnn.rnn( cell, encoder_inputs, dtype=dtype, scope=scope) vs.get_variable_scope().reuse_variables() - return rnn_decoder(decoder_inputs, enc_states[-1], cell, + return rnn_decoder(decoder_inputs, enc_state, cell, loop_function=loop_function, scope=scope) @@ -167,9 +164,9 @@ def embedding_rnn_decoder(decoder_inputs, initial_state, cell, num_symbols, Returns: outputs: A list of the same length as decoder_inputs of 2D Tensors with shape [batch_size x cell.output_size] containing the generated outputs. - states: The state of each decoder cell in each time-step. This is a list + state: The state of each decoder cell in each time-step. This is a list with length len(decoder_inputs) -- one item for each time-step. - Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + It is a 2D Tensor of shape [batch_size x cell.state_size]. Raises: ValueError: when output_projection has the wrong shape. @@ -240,37 +237,37 @@ def embedding_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, Returns: outputs: A list of the same length as decoder_inputs of 2D Tensors with shape [batch_size x num_decoder_symbols] containing the generated outputs. - states: The state of each decoder cell in each time-step. This is a list + state: The state of each decoder cell in each time-step. This is a list with length len(decoder_inputs) -- one item for each time-step. - Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + It is a 2D Tensor of shape [batch_size x cell.state_size]. """ with vs.variable_scope(scope or "embedding_rnn_seq2seq"): # Encoder. 
encoder_cell = rnn_cell.EmbeddingWrapper(cell, num_encoder_symbols) - _, encoder_states = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype) + _, encoder_state = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype) # Decoder. if output_projection is None: cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) if isinstance(feed_previous, bool): - return embedding_rnn_decoder(decoder_inputs, encoder_states[-1], cell, + return embedding_rnn_decoder(decoder_inputs, encoder_state, cell, num_decoder_symbols, output_projection, feed_previous) else: # If feed_previous is a Tensor, we construct 2 graphs and use cond. - outputs1, states1 = embedding_rnn_decoder( - decoder_inputs, encoder_states[-1], cell, num_decoder_symbols, + outputs1, state1 = embedding_rnn_decoder( + decoder_inputs, encoder_state, cell, num_decoder_symbols, output_projection, True) vs.get_variable_scope().reuse_variables() - outputs2, states2 = embedding_rnn_decoder( - decoder_inputs, encoder_states[-1], cell, num_decoder_symbols, + outputs2, state2 = embedding_rnn_decoder( + decoder_inputs, encoder_state, cell, num_decoder_symbols, output_projection, False) outputs = control_flow_ops.cond(feed_previous, lambda: outputs1, lambda: outputs2) - states = control_flow_ops.cond(feed_previous, - lambda: states1, lambda: states2) - return outputs, states + state = control_flow_ops.cond(feed_previous, + lambda: state1, lambda: state2) + return outputs, state def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, @@ -305,9 +302,8 @@ def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, Returns: outputs: A list of the same length as decoder_inputs of 2D Tensors with shape [batch_size x num_decoder_symbols] containing the generated outputs. - states: The state of each decoder cell in each time-step. This is a list - with length len(decoder_inputs) -- one item for each time-step. - Each item is a 2D Tensor of shape [batch_size x cell.state_size]. 
+ state: The state of each decoder cell at the final time-step. + It is a 2D Tensor of shape [batch_size x cell.state_size]. Raises: ValueError: when output_projection has the wrong shape. @@ -344,18 +340,18 @@ def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, return tied_rnn_seq2seq(emb_encoder_inputs, emb_decoder_inputs, cell, loop_function=loop_function, dtype=dtype) else: # If feed_previous is a Tensor, we construct 2 graphs and use cond. - outputs1, states1 = tied_rnn_seq2seq( + outputs1, state1 = tied_rnn_seq2seq( emb_encoder_inputs, emb_decoder_inputs, cell, loop_function=extract_argmax_and_embed, dtype=dtype) vs.get_variable_scope().reuse_variables() - outputs2, states2 = tied_rnn_seq2seq( + outputs2, state2 = tied_rnn_seq2seq( emb_encoder_inputs, emb_decoder_inputs, cell, dtype=dtype) outputs = control_flow_ops.cond(feed_previous, lambda: outputs1, lambda: outputs2) - states = control_flow_ops.cond(feed_previous, - lambda: states1, lambda: states2) - return outputs, states + state = control_flow_ops.cond(feed_previous, + lambda: state1, lambda: state2) + return outputs, state def attention_decoder(decoder_inputs, initial_state, attention_states, cell, @@ -397,9 +393,8 @@ def attention_decoder(decoder_inputs, initial_state, attention_states, cell, new_attn = softmax(V^T * tanh(W * attention_states + U * new_state)) and then we calculate the output: output = linear(cell_output, new_attn). - states: The state of each decoder cell in each time-step. This is a list - with length len(decoder_inputs) -- one item for each time-step. - Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + state: The state of each decoder cell at the final time-step. + It is a 2D Tensor of shape [batch_size x cell.state_size]. 
Raises: ValueError: when num_heads is not positive, there are no inputs, or shapes @@ -431,7 +426,7 @@ def attention_decoder(decoder_inputs, initial_state, attention_states, cell, hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME")) v.append(vs.get_variable("AttnV_%d" % a, [attention_vec_size])) - states = [initial_state] + state = initial_state def attention(query): """Put attention masks on hidden using hidden_features and query.""" @@ -471,14 +466,13 @@ def attention_decoder(decoder_inputs, initial_state, attention_states, cell, # Merge input and previous attentions into one vector of the right size. x = rnn_cell.linear([inp] + attns, cell.input_size, True) # Run the RNN. - cell_output, new_state = cell(x, states[-1]) - states.append(new_state) + cell_output, state = cell(x, state) # Run the attention mechanism. if i == 0 and initial_state_attention: with vs.variable_scope(vs.get_variable_scope(), reuse=True): - attns = attention(new_state) + attns = attention(state) else: - attns = attention(new_state) + attns = attention(state) with vs.variable_scope("AttnOutputProjection"): output = rnn_cell.linear([cell_output] + attns, output_size, True) @@ -487,7 +481,7 @@ def attention_decoder(decoder_inputs, initial_state, attention_states, cell, prev = array_ops.stop_gradient(output) outputs.append(output) - return outputs, states + return outputs, state def embedding_attention_decoder(decoder_inputs, initial_state, attention_states, @@ -526,9 +520,8 @@ def embedding_attention_decoder(decoder_inputs, initial_state, attention_states, Returns: outputs: A list of the same length as decoder_inputs of 2D Tensors with shape [batch_size x output_size] containing the generated outputs. - states: The state of each decoder cell in each time-step. This is a list - with length len(decoder_inputs) -- one item for each time-step. - Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + state: The state of each decoder cell at the final time-step. 
+ It is a 2D Tensor of shape [batch_size x cell.state_size]. Raises: ValueError: when output_projection has the wrong shape. @@ -607,14 +600,13 @@ def embedding_attention_seq2seq(encoder_inputs, decoder_inputs, cell, Returns: outputs: A list of the same length as decoder_inputs of 2D Tensors with shape [batch_size x num_decoder_symbols] containing the generated outputs. - states: The state of each decoder cell in each time-step. This is a list - with length len(decoder_inputs) -- one item for each time-step. - Each item is a 2D Tensor of shape [batch_size x cell.state_size]. + state: The state of each decoder cell at the final time-step. + It is a 2D Tensor of shape [batch_size x cell.state_size]. """ with vs.variable_scope(scope or "embedding_attention_seq2seq"): # Encoder. encoder_cell = rnn_cell.EmbeddingWrapper(cell, num_encoder_symbols) - encoder_outputs, encoder_states = rnn.rnn( + encoder_outputs, encoder_state = rnn.rnn( encoder_cell, encoder_inputs, dtype=dtype) # First calculate a concatenation of encoder outputs to put attention on. @@ -630,25 +622,25 @@ def embedding_attention_seq2seq(encoder_inputs, decoder_inputs, cell, if isinstance(feed_previous, bool): return embedding_attention_decoder( - decoder_inputs, encoder_states[-1], attention_states, cell, + decoder_inputs, encoder_state, attention_states, cell, num_decoder_symbols, num_heads, output_size, output_projection, feed_previous, initial_state_attention=initial_state_attention) else: # If feed_previous is a Tensor, we construct 2 graphs and use cond. 
- outputs1, states1 = embedding_attention_decoder( - decoder_inputs, encoder_states[-1], attention_states, cell, + outputs1, state1 = embedding_attention_decoder( + decoder_inputs, encoder_state, attention_states, cell, num_decoder_symbols, num_heads, output_size, output_projection, True, initial_state_attention=initial_state_attention) vs.get_variable_scope().reuse_variables() - outputs2, states2 = embedding_attention_decoder( - decoder_inputs, encoder_states[-1], attention_states, cell, + outputs2, state2 = embedding_attention_decoder( + decoder_inputs, encoder_state, attention_states, cell, num_decoder_symbols, num_heads, output_size, output_projection, False, initial_state_attention=initial_state_attention) outputs = control_flow_ops.cond(feed_previous, lambda: outputs1, lambda: outputs2) - states = control_flow_ops.cond(feed_previous, - lambda: states1, lambda: states2) - return outputs, states + state = control_flow_ops.cond(feed_previous, + lambda: state1, lambda: state2) + return outputs, state def sequence_loss_by_example(logits, targets, weights, num_decoder_symbols,