Breaking change in TF RNN python api: Return the final state instead of the
list of states when calling tf.nn.rnn() and tf.nn.state_saving_rnn(). This is necessary for further cleanup of the RNN state propagation code (currently, dynamic RNN calculations do not return the proper final state when sequence_length is passed; this change is required to make that fix efficient). Change: 113203893
This commit is contained in:
parent
e59493941c
commit
fea55e1e05
@ -31,6 +31,9 @@
|
||||
* ASSERT_OK / EXPECT_OK macros conflicted with external projects, so they were
|
||||
renamed TF_ASSERT_OK, TF_EXPECT_OK. The existing macros are currently
|
||||
maintained for short-term compatibility but will be removed.
|
||||
* The non-public `nn.rnn` and the various `nn.seq2seq` methods now return
|
||||
just the final state instead of the list of all states.
|
||||
|
||||
|
||||
## Bug fixes
|
||||
|
||||
|
@ -117,16 +117,14 @@ class PTBModel(object):
|
||||
# from tensorflow.models.rnn import rnn
|
||||
# inputs = [tf.squeeze(input_, [1])
|
||||
# for input_ in tf.split(1, num_steps, inputs)]
|
||||
# outputs, states = rnn.rnn(cell, inputs, initial_state=self._initial_state)
|
||||
# outputs, state = rnn.rnn(cell, inputs, initial_state=self._initial_state)
|
||||
outputs = []
|
||||
states = []
|
||||
state = self._initial_state
|
||||
with tf.variable_scope("RNN"):
|
||||
for time_step in range(num_steps):
|
||||
if time_step > 0: tf.get_variable_scope().reuse_variables()
|
||||
(cell_output, state) = cell(inputs[:, time_step, :], state)
|
||||
outputs.append(cell_output)
|
||||
states.append(state)
|
||||
|
||||
output = tf.reshape(tf.concat(1, outputs), [-1, size])
|
||||
softmax_w = tf.get_variable("softmax_w", [size, vocab_size])
|
||||
@ -137,7 +135,7 @@ class PTBModel(object):
|
||||
[tf.ones([batch_size * num_steps])],
|
||||
vocab_size)
|
||||
self._cost = cost = tf.reduce_sum(loss) / batch_size
|
||||
self._final_state = states[-1]
|
||||
self._final_state = state
|
||||
|
||||
if not is_training:
|
||||
return
|
||||
|
@ -68,7 +68,7 @@ class RNNTest(tf.test.TestCase):
|
||||
max_length = 8 # unrolled up to this length
|
||||
inputs = max_length * [
|
||||
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
|
||||
outputs, states = tf.nn.rnn(cell, inputs, dtype=tf.float32)
|
||||
outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
|
||||
self.assertEqual(len(outputs), len(inputs))
|
||||
for out, inp in zip(outputs, inputs):
|
||||
self.assertEqual(out.get_shape(), inp.get_shape())
|
||||
@ -76,7 +76,7 @@ class RNNTest(tf.test.TestCase):
|
||||
|
||||
with self.test_session(use_gpu=False) as sess:
|
||||
input_value = np.random.randn(batch_size, input_size)
|
||||
values = sess.run(outputs + [states[-1]],
|
||||
values = sess.run(outputs + [state],
|
||||
feed_dict={inputs[0]: input_value})
|
||||
|
||||
# Outputs
|
||||
@ -98,7 +98,7 @@ class RNNTest(tf.test.TestCase):
|
||||
inputs = max_length * [
|
||||
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
|
||||
with tf.variable_scope("share_scope"):
|
||||
outputs, states = tf.nn.rnn(cell, inputs, dtype=tf.float32)
|
||||
outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
|
||||
with tf.variable_scope("drop_scope"):
|
||||
dropped_outputs, _ = tf.nn.rnn(
|
||||
full_dropout_cell, inputs, dtype=tf.float32)
|
||||
@ -109,7 +109,7 @@ class RNNTest(tf.test.TestCase):
|
||||
|
||||
with self.test_session(use_gpu=False) as sess:
|
||||
input_value = np.random.randn(batch_size, input_size)
|
||||
values = sess.run(outputs + [states[-1]],
|
||||
values = sess.run(outputs + [state],
|
||||
feed_dict={inputs[0]: input_value})
|
||||
full_dropout_values = sess.run(dropped_outputs,
|
||||
feed_dict={inputs[0]: input_value})
|
||||
@ -128,31 +128,29 @@ class RNNTest(tf.test.TestCase):
|
||||
inputs = max_length * [
|
||||
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
|
||||
with tf.variable_scope("drop_scope"):
|
||||
dynamic_outputs, dynamic_states = tf.nn.rnn(
|
||||
dynamic_outputs, dynamic_state = tf.nn.rnn(
|
||||
cell, inputs, sequence_length=sequence_length, dtype=tf.float32)
|
||||
self.assertEqual(len(dynamic_outputs), len(inputs))
|
||||
self.assertEqual(len(dynamic_states), len(inputs))
|
||||
|
||||
with self.test_session(use_gpu=False) as sess:
|
||||
input_value = np.random.randn(batch_size, input_size)
|
||||
dynamic_values = sess.run(dynamic_outputs,
|
||||
feed_dict={inputs[0]: input_value,
|
||||
sequence_length: [2, 3]})
|
||||
dynamic_state_values = sess.run(dynamic_states,
|
||||
dynamic_state_values = sess.run([dynamic_state],
|
||||
feed_dict={inputs[0]: input_value,
|
||||
sequence_length: [2, 3]})
|
||||
|
||||
# fully calculated for t = 0, 1, 2
|
||||
for v in dynamic_values[:3]:
|
||||
self.assertAllClose(v, input_value + 1.0)
|
||||
for vi, v in enumerate(dynamic_state_values[:3]):
|
||||
self.assertAllEqual(v, 1.0 * (vi + 1) *
|
||||
np.ones((batch_size, input_size)))
|
||||
# zeros for t = 3+
|
||||
for v in dynamic_values[3:]:
|
||||
self.assertAllEqual(v, np.zeros_like(input_value))
|
||||
for v in dynamic_state_values[3:]:
|
||||
self.assertAllEqual(v, np.zeros_like(input_value))
|
||||
# final state is frozen from the state at step max(sequence_length) - 1 == 2
|
||||
self.assertAllEqual(
|
||||
dynamic_state_values[0],
|
||||
1.0 * (2 + 1) * np.ones((batch_size, input_size)))
|
||||
|
||||
|
||||
class LSTMTest(tf.test.TestCase):
|
||||
@ -219,7 +217,7 @@ class LSTMTest(tf.test.TestCase):
|
||||
inputs = max_length * [
|
||||
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
|
||||
with tf.variable_scope("share_scope"):
|
||||
outputs, states = tf.nn.state_saving_rnn(
|
||||
outputs, state = tf.nn.state_saving_rnn(
|
||||
cell, inputs, state_saver=state_saver, state_name="save_lstm")
|
||||
self.assertEqual(len(outputs), len(inputs))
|
||||
for out in outputs:
|
||||
@ -228,7 +226,7 @@ class LSTMTest(tf.test.TestCase):
|
||||
tf.initialize_all_variables().run()
|
||||
input_value = np.random.randn(batch_size, input_size)
|
||||
(last_state_value, saved_state_value) = sess.run(
|
||||
[states[-1], state_saver.saved_state],
|
||||
[state, state_saver.saved_state],
|
||||
feed_dict={inputs[0]: input_value})
|
||||
self.assertAllEqual(last_state_value, saved_state_value)
|
||||
|
||||
@ -340,10 +338,10 @@ class LSTMTest(tf.test.TestCase):
|
||||
initializer=initializer, num_proj=num_proj)
|
||||
|
||||
with tf.variable_scope("noshard_scope"):
|
||||
outputs_noshard, states_noshard = tf.nn.rnn(
|
||||
outputs_noshard, state_noshard = tf.nn.rnn(
|
||||
cell_noshard, inputs, dtype=tf.float32)
|
||||
with tf.variable_scope("shard_scope"):
|
||||
outputs_shard, states_shard = tf.nn.rnn(
|
||||
outputs_shard, state_shard = tf.nn.rnn(
|
||||
cell_shard, inputs, dtype=tf.float32)
|
||||
|
||||
self.assertEqual(len(outputs_noshard), len(inputs))
|
||||
@ -354,8 +352,8 @@ class LSTMTest(tf.test.TestCase):
|
||||
feeds = dict((x, input_value) for x in inputs)
|
||||
values_noshard = sess.run(outputs_noshard, feed_dict=feeds)
|
||||
values_shard = sess.run(outputs_shard, feed_dict=feeds)
|
||||
state_values_noshard = sess.run(states_noshard, feed_dict=feeds)
|
||||
state_values_shard = sess.run(states_shard, feed_dict=feeds)
|
||||
state_values_noshard = sess.run([state_noshard], feed_dict=feeds)
|
||||
state_values_shard = sess.run([state_shard], feed_dict=feeds)
|
||||
self.assertEqual(len(values_noshard), len(values_shard))
|
||||
self.assertEqual(len(state_values_noshard), len(state_values_shard))
|
||||
for (v_noshard, v_shard) in zip(values_noshard, values_shard):
|
||||
@ -389,22 +387,21 @@ class LSTMTest(tf.test.TestCase):
|
||||
initializer=initializer)
|
||||
dropout_cell = tf.nn.rnn_cell.DropoutWrapper(cell, 0.5, seed=0)
|
||||
|
||||
outputs, states = tf.nn.rnn(
|
||||
outputs, state = tf.nn.rnn(
|
||||
dropout_cell, inputs, sequence_length=sequence_length,
|
||||
initial_state=cell.zero_state(batch_size, tf.float64))
|
||||
|
||||
self.assertEqual(len(outputs), len(inputs))
|
||||
self.assertEqual(len(outputs), len(states))
|
||||
|
||||
tf.initialize_all_variables().run(feed_dict={sequence_length: [2, 3]})
|
||||
input_value = np.asarray(np.random.randn(batch_size, input_size),
|
||||
dtype=np.float64)
|
||||
values = sess.run(outputs, feed_dict={inputs[0]: input_value,
|
||||
sequence_length: [2, 3]})
|
||||
state_values = sess.run(states, feed_dict={inputs[0]: input_value,
|
||||
state_value = sess.run([state], feed_dict={inputs[0]: input_value,
|
||||
sequence_length: [2, 3]})
|
||||
self.assertEqual(values[0].dtype, input_value.dtype)
|
||||
self.assertEqual(state_values[0].dtype, input_value.dtype)
|
||||
self.assertEqual(state_value[0].dtype, input_value.dtype)
|
||||
|
||||
def testSharingWeightsWithReuse(self):
|
||||
num_units = 3
|
||||
|
@ -35,19 +35,18 @@ class Seq2SeqTest(tf.test.TestCase):
|
||||
with self.test_session() as sess:
|
||||
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
|
||||
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
|
||||
_, enc_states = tf.nn.rnn(
|
||||
_, enc_state = tf.nn.rnn(
|
||||
tf.nn.rnn_cell.GRUCell(2), inp, dtype=tf.float32)
|
||||
dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)]
|
||||
cell = tf.nn.rnn_cell.OutputProjectionWrapper(
|
||||
tf.nn.rnn_cell.GRUCell(2), 4)
|
||||
dec, mem = tf.nn.seq2seq.rnn_decoder(dec_inp, enc_states[-1], cell)
|
||||
dec, mem = tf.nn.seq2seq.rnn_decoder(dec_inp, enc_state, cell)
|
||||
sess.run([tf.initialize_all_variables()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(len(res), 3)
|
||||
self.assertEqual(res[0].shape, (2, 4))
|
||||
|
||||
res = sess.run(mem)
|
||||
self.assertEqual(len(res), 4)
|
||||
res = sess.run([mem])
|
||||
self.assertEqual(res[0].shape, (2, 2))
|
||||
|
||||
def testBasicRNNSeq2Seq(self):
|
||||
@ -63,8 +62,7 @@ class Seq2SeqTest(tf.test.TestCase):
|
||||
self.assertEqual(len(res), 3)
|
||||
self.assertEqual(res[0].shape, (2, 4))
|
||||
|
||||
res = sess.run(mem)
|
||||
self.assertEqual(len(res), 4)
|
||||
res = sess.run([mem])
|
||||
self.assertEqual(res[0].shape, (2, 2))
|
||||
|
||||
def testTiedRNNSeq2Seq(self):
|
||||
@ -80,8 +78,8 @@ class Seq2SeqTest(tf.test.TestCase):
|
||||
self.assertEqual(len(res), 3)
|
||||
self.assertEqual(res[0].shape, (2, 4))
|
||||
|
||||
res = sess.run(mem)
|
||||
self.assertEqual(len(res), 4)
|
||||
res = sess.run([mem])
|
||||
self.assertEqual(len(res), 1)
|
||||
self.assertEqual(res[0].shape, (2, 2))
|
||||
|
||||
def testEmbeddingRNNDecoder(self):
|
||||
@ -89,17 +87,17 @@ class Seq2SeqTest(tf.test.TestCase):
|
||||
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
|
||||
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
|
||||
cell = tf.nn.rnn_cell.BasicLSTMCell(2)
|
||||
_, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32)
|
||||
_, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
|
||||
dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)]
|
||||
dec, mem = tf.nn.seq2seq.embedding_rnn_decoder(dec_inp, enc_states[-1],
|
||||
dec, mem = tf.nn.seq2seq.embedding_rnn_decoder(dec_inp, enc_state,
|
||||
cell, 4)
|
||||
sess.run([tf.initialize_all_variables()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(len(res), 3)
|
||||
self.assertEqual(res[0].shape, (2, 2))
|
||||
|
||||
res = sess.run(mem)
|
||||
self.assertEqual(len(res), 4)
|
||||
res = sess.run([mem])
|
||||
self.assertEqual(len(res), 1)
|
||||
self.assertEqual(res[0].shape, (2, 4))
|
||||
|
||||
def testEmbeddingRNNSeq2Seq(self):
|
||||
@ -115,8 +113,7 @@ class Seq2SeqTest(tf.test.TestCase):
|
||||
self.assertEqual(len(res), 3)
|
||||
self.assertEqual(res[0].shape, (2, 5))
|
||||
|
||||
res = sess.run(mem)
|
||||
self.assertEqual(len(res), 4)
|
||||
res = sess.run([mem])
|
||||
self.assertEqual(res[0].shape, (2, 4))
|
||||
|
||||
# Test externally provided output projection.
|
||||
@ -161,8 +158,7 @@ class Seq2SeqTest(tf.test.TestCase):
|
||||
self.assertEqual(len(res), 3)
|
||||
self.assertEqual(res[0].shape, (2, 5))
|
||||
|
||||
res = sess.run(mem)
|
||||
self.assertEqual(len(res), 4)
|
||||
res = sess.run([mem])
|
||||
self.assertEqual(res[0].shape, (2, 4))
|
||||
|
||||
# Test externally provided output projection.
|
||||
@ -198,20 +194,19 @@ class Seq2SeqTest(tf.test.TestCase):
|
||||
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
|
||||
cell = tf.nn.rnn_cell.GRUCell(2)
|
||||
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
|
||||
enc_outputs, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32)
|
||||
enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
|
||||
attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size])
|
||||
for e in enc_outputs])
|
||||
dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)]
|
||||
dec, mem = tf.nn.seq2seq.attention_decoder(
|
||||
dec_inp, enc_states[-1],
|
||||
dec_inp, enc_state,
|
||||
attn_states, cell, output_size=4)
|
||||
sess.run([tf.initialize_all_variables()])
|
||||
res = sess.run(dec)
|
||||
self.assertEqual(len(res), 3)
|
||||
self.assertEqual(res[0].shape, (2, 4))
|
||||
|
||||
res = sess.run(mem)
|
||||
self.assertEqual(len(res), 4)
|
||||
res = sess.run([mem])
|
||||
self.assertEqual(res[0].shape, (2, 2))
|
||||
|
||||
def testAttentionDecoder2(self):
|
||||
@ -219,12 +214,12 @@ class Seq2SeqTest(tf.test.TestCase):
|
||||
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
|
||||
cell = tf.nn.rnn_cell.GRUCell(2)
|
||||
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
|
||||
enc_outputs, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32)
|
||||
enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
|
||||
attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size])
|
||||
for e in enc_outputs])
|
||||
dec_inp = [tf.constant(0.4, shape=[2, 2]) for _ in xrange(3)]
|
||||
dec, mem = tf.nn.seq2seq.attention_decoder(
|
||||
dec_inp, enc_states[-1],
|
||||
dec_inp, enc_state,
|
||||
attn_states, cell, output_size=4,
|
||||
num_heads=2)
|
||||
sess.run([tf.initialize_all_variables()])
|
||||
@ -232,8 +227,7 @@ class Seq2SeqTest(tf.test.TestCase):
|
||||
self.assertEqual(len(res), 3)
|
||||
self.assertEqual(res[0].shape, (2, 4))
|
||||
|
||||
res = sess.run(mem)
|
||||
self.assertEqual(len(res), 4)
|
||||
res = sess.run([mem])
|
||||
self.assertEqual(res[0].shape, (2, 2))
|
||||
|
||||
def testEmbeddingAttentionDecoder(self):
|
||||
@ -241,12 +235,12 @@ class Seq2SeqTest(tf.test.TestCase):
|
||||
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
|
||||
inp = [tf.constant(0.5, shape=[2, 2]) for _ in xrange(2)]
|
||||
cell = tf.nn.rnn_cell.GRUCell(2)
|
||||
enc_outputs, enc_states = tf.nn.rnn(cell, inp, dtype=tf.float32)
|
||||
enc_outputs, enc_state = tf.nn.rnn(cell, inp, dtype=tf.float32)
|
||||
attn_states = tf.concat(1, [tf.reshape(e, [-1, 1, cell.output_size])
|
||||
for e in enc_outputs])
|
||||
dec_inp = [tf.constant(i, tf.int32, shape=[2]) for i in xrange(3)]
|
||||
dec, mem = tf.nn.seq2seq.embedding_attention_decoder(
|
||||
dec_inp, enc_states[-1],
|
||||
dec_inp, enc_state,
|
||||
attn_states, cell, 4,
|
||||
output_size=3)
|
||||
sess.run([tf.initialize_all_variables()])
|
||||
@ -254,8 +248,7 @@ class Seq2SeqTest(tf.test.TestCase):
|
||||
self.assertEqual(len(res), 3)
|
||||
self.assertEqual(res[0].shape, (2, 3))
|
||||
|
||||
res = sess.run(mem)
|
||||
self.assertEqual(len(res), 4)
|
||||
res = sess.run([mem])
|
||||
self.assertEqual(res[0].shape, (2, 2))
|
||||
|
||||
def testEmbeddingAttentionSeq2Seq(self):
|
||||
@ -271,8 +264,7 @@ class Seq2SeqTest(tf.test.TestCase):
|
||||
self.assertEqual(len(res), 3)
|
||||
self.assertEqual(res[0].shape, (2, 5))
|
||||
|
||||
res = sess.run(mem)
|
||||
self.assertEqual(len(res), 4)
|
||||
res = sess.run([mem])
|
||||
self.assertEqual(res[0].shape, (2, 4))
|
||||
|
||||
# Test externally provided output projection.
|
||||
|
@ -34,12 +34,10 @@ def rnn(cell, inputs, initial_state=None, dtype=None,
|
||||
The simplest form of RNN network generated is:
|
||||
state = cell.zero_state(...)
|
||||
outputs = []
|
||||
states = []
|
||||
for input_ in inputs:
|
||||
output, state = cell(input_, state)
|
||||
outputs.append(output)
|
||||
states.append(state)
|
||||
return (outputs, states)
|
||||
return (outputs, state)
|
||||
|
||||
However, a few other options are available:
|
||||
|
||||
@ -65,9 +63,9 @@ def rnn(cell, inputs, initial_state=None, dtype=None,
|
||||
scope: VariableScope for the created subgraph; defaults to "RNN".
|
||||
|
||||
Returns:
|
||||
A pair (outputs, states) where:
|
||||
A pair (outputs, state) where:
|
||||
outputs is a length T list of outputs (one for each input)
|
||||
states is a length T list of states (one state following each input)
|
||||
state is the final state
|
||||
|
||||
Raises:
|
||||
TypeError: If "cell" is not an instance of RNNCell.
|
||||
@ -82,7 +80,6 @@ def rnn(cell, inputs, initial_state=None, dtype=None,
|
||||
raise ValueError("inputs must not be empty")
|
||||
|
||||
outputs = []
|
||||
states = []
|
||||
with vs.variable_scope(scope or "RNN"):
|
||||
batch_size = array_ops.shape(inputs[0])[0]
|
||||
if initial_state is not None:
|
||||
@ -93,30 +90,25 @@ def rnn(cell, inputs, initial_state=None, dtype=None,
|
||||
state = cell.zero_state(batch_size, dtype)
|
||||
|
||||
if sequence_length: # Prepare variables
|
||||
zero_output_state = (
|
||||
array_ops.zeros(array_ops.pack([batch_size, cell.output_size]),
|
||||
inputs[0].dtype),
|
||||
array_ops.zeros(array_ops.pack([batch_size, cell.state_size]),
|
||||
state.dtype))
|
||||
zero_output = array_ops.zeros(
|
||||
array_ops.pack([batch_size, cell.output_size]), inputs[0].dtype)
|
||||
max_sequence_length = math_ops.reduce_max(sequence_length)
|
||||
|
||||
for time, input_ in enumerate(inputs):
|
||||
if time > 0: vs.get_variable_scope().reuse_variables()
|
||||
# pylint: disable=cell-var-from-loop
|
||||
def output_state():
|
||||
return cell(input_, state)
|
||||
output_state = lambda: cell(input_, state)
|
||||
# pylint: enable=cell-var-from-loop
|
||||
if sequence_length:
|
||||
(output, state) = control_flow_ops.cond(
|
||||
time >= max_sequence_length,
|
||||
lambda: zero_output_state, output_state)
|
||||
lambda: (zero_output, state), output_state)
|
||||
else:
|
||||
(output, state) = output_state()
|
||||
|
||||
outputs.append(output)
|
||||
states.append(state)
|
||||
|
||||
return (outputs, states)
|
||||
return (outputs, state)
|
||||
|
||||
|
||||
def state_saving_rnn(cell, inputs, state_saver, state_name,
|
||||
@ -134,22 +126,22 @@ def state_saving_rnn(cell, inputs, state_saver, state_name,
|
||||
scope: VariableScope for the created subgraph; defaults to "RNN".
|
||||
|
||||
Returns:
|
||||
A pair (outputs, states) where:
|
||||
A pair (outputs, state) where:
|
||||
outputs is a length T list of outputs (one for each input)
|
||||
states is a length T list of states (one state following each input)
|
||||
state is the final state
|
||||
|
||||
Raises:
|
||||
TypeError: If "cell" is not an instance of RNNCell.
|
||||
ValueError: If inputs is None or an empty list.
|
||||
"""
|
||||
initial_state = state_saver.state(state_name)
|
||||
(outputs, states) = rnn(cell, inputs, initial_state=initial_state,
|
||||
sequence_length=sequence_length, scope=scope)
|
||||
save_state = state_saver.save_state(state_name, states[-1])
|
||||
(outputs, state) = rnn(cell, inputs, initial_state=initial_state,
|
||||
sequence_length=sequence_length, scope=scope)
|
||||
save_state = state_saver.save_state(state_name, state)
|
||||
with ops.control_dependencies([save_state]):
|
||||
outputs[-1] = array_ops.identity(outputs[-1])
|
||||
|
||||
return (outputs, states)
|
||||
return (outputs, state)
|
||||
|
||||
|
||||
def _reverse_seq(input_seq, lengths):
|
||||
|
@ -54,14 +54,13 @@ def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None,
|
||||
Returns:
|
||||
outputs: A list of the same length as decoder_inputs of 2D Tensors with
|
||||
shape [batch_size x cell.output_size] containing generated outputs.
|
||||
states: The state of each cell in each time-step. This is a list with
|
||||
length len(decoder_inputs) -- one item for each time-step.
|
||||
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
|
||||
state: The state of each cell at the final time-step.
|
||||
It is a 2D Tensor of shape [batch_size x cell.state_size].
|
||||
(Note that in some cases, like basic RNN cell or GRU cell, outputs and
|
||||
states can be the same. They are different for LSTM cells though.)
|
||||
"""
|
||||
with vs.variable_scope(scope or "rnn_decoder"):
|
||||
states = [initial_state]
|
||||
state = initial_state
|
||||
outputs = []
|
||||
prev = None
|
||||
for i in xrange(len(decoder_inputs)):
|
||||
@ -72,12 +71,11 @@ def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None,
|
||||
inp = array_ops.stop_gradient(loop_function(prev, i))
|
||||
if i > 0:
|
||||
vs.get_variable_scope().reuse_variables()
|
||||
output, new_state = cell(inp, states[-1])
|
||||
output, state = cell(inp, state)
|
||||
outputs.append(output)
|
||||
states.append(new_state)
|
||||
if loop_function is not None:
|
||||
prev = array_ops.stop_gradient(output)
|
||||
return outputs, states
|
||||
return outputs, state
|
||||
|
||||
|
||||
def basic_rnn_seq2seq(
|
||||
@ -98,13 +96,12 @@ def basic_rnn_seq2seq(
|
||||
Returns:
|
||||
outputs: A list of the same length as decoder_inputs of 2D Tensors with
|
||||
shape [batch_size x cell.output_size] containing the generated outputs.
|
||||
states: The state of each decoder cell in each time-step. This is a list
|
||||
with length len(decoder_inputs) -- one item for each time-step.
|
||||
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
|
||||
state: The state of each decoder cell in the final time-step.
|
||||
It is a 2D Tensor of shape [batch_size x cell.state_size].
|
||||
"""
|
||||
with vs.variable_scope(scope or "basic_rnn_seq2seq"):
|
||||
_, enc_states = rnn.rnn(cell, encoder_inputs, dtype=dtype)
|
||||
return rnn_decoder(decoder_inputs, enc_states[-1], cell)
|
||||
_, enc_state = rnn.rnn(cell, encoder_inputs, dtype=dtype)
|
||||
return rnn_decoder(decoder_inputs, enc_state, cell)
|
||||
|
||||
|
||||
def tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
|
||||
@ -128,16 +125,16 @@ def tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
|
||||
Returns:
|
||||
outputs: A list of the same length as decoder_inputs of 2D Tensors with
|
||||
shape [batch_size x cell.output_size] containing the generated outputs.
|
||||
states: The state of each decoder cell in each time-step. This is a list
|
||||
state: The state of each decoder cell at the final time-step of a sequence
|
||||
with length len(decoder_inputs) -- one item for each time-step.
|
||||
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
|
||||
It is a 2D Tensor of shape [batch_size x cell.state_size].
|
||||
"""
|
||||
with vs.variable_scope("combined_tied_rnn_seq2seq"):
|
||||
scope = scope or "tied_rnn_seq2seq"
|
||||
_, enc_states = rnn.rnn(
|
||||
_, enc_state = rnn.rnn(
|
||||
cell, encoder_inputs, dtype=dtype, scope=scope)
|
||||
vs.get_variable_scope().reuse_variables()
|
||||
return rnn_decoder(decoder_inputs, enc_states[-1], cell,
|
||||
return rnn_decoder(decoder_inputs, enc_state, cell,
|
||||
loop_function=loop_function, scope=scope)
|
||||
|
||||
|
||||
@ -167,9 +164,9 @@ def embedding_rnn_decoder(decoder_inputs, initial_state, cell, num_symbols,
|
||||
Returns:
|
||||
outputs: A list of the same length as decoder_inputs of 2D Tensors with
|
||||
shape [batch_size x cell.output_size] containing the generated outputs.
|
||||
states: The state of each decoder cell in each time-step. This is a list
|
||||
state: The state of each decoder cell at the final time-step of a sequence
|
||||
with length len(decoder_inputs) -- one item for each time-step.
|
||||
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
|
||||
It is a 2D Tensor of shape [batch_size x cell.state_size].
|
||||
|
||||
Raises:
|
||||
ValueError: when output_projection has the wrong shape.
|
||||
@ -240,37 +237,37 @@ def embedding_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
|
||||
Returns:
|
||||
outputs: A list of the same length as decoder_inputs of 2D Tensors with
|
||||
shape [batch_size x num_decoder_symbols] containing the generated outputs.
|
||||
states: The state of each decoder cell in each time-step. This is a list
|
||||
state: The state of each decoder cell at the final time-step of a sequence
|
||||
with length len(decoder_inputs) -- one item for each time-step.
|
||||
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
|
||||
It is a 2D Tensor of shape [batch_size x cell.state_size].
|
||||
"""
|
||||
with vs.variable_scope(scope or "embedding_rnn_seq2seq"):
|
||||
# Encoder.
|
||||
encoder_cell = rnn_cell.EmbeddingWrapper(cell, num_encoder_symbols)
|
||||
_, encoder_states = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype)
|
||||
_, encoder_state = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype)
|
||||
|
||||
# Decoder.
|
||||
if output_projection is None:
|
||||
cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)
|
||||
|
||||
if isinstance(feed_previous, bool):
|
||||
return embedding_rnn_decoder(decoder_inputs, encoder_states[-1], cell,
|
||||
return embedding_rnn_decoder(decoder_inputs, encoder_state, cell,
|
||||
num_decoder_symbols, output_projection,
|
||||
feed_previous)
|
||||
else: # If feed_previous is a Tensor, we construct 2 graphs and use cond.
|
||||
outputs1, states1 = embedding_rnn_decoder(
|
||||
decoder_inputs, encoder_states[-1], cell, num_decoder_symbols,
|
||||
outputs1, state1 = embedding_rnn_decoder(
|
||||
decoder_inputs, encoder_state, cell, num_decoder_symbols,
|
||||
output_projection, True)
|
||||
vs.get_variable_scope().reuse_variables()
|
||||
outputs2, states2 = embedding_rnn_decoder(
|
||||
decoder_inputs, encoder_states[-1], cell, num_decoder_symbols,
|
||||
outputs2, state2 = embedding_rnn_decoder(
|
||||
decoder_inputs, encoder_state, cell, num_decoder_symbols,
|
||||
output_projection, False)
|
||||
|
||||
outputs = control_flow_ops.cond(feed_previous,
|
||||
lambda: outputs1, lambda: outputs2)
|
||||
states = control_flow_ops.cond(feed_previous,
|
||||
lambda: states1, lambda: states2)
|
||||
return outputs, states
|
||||
state = control_flow_ops.cond(feed_previous,
|
||||
lambda: state1, lambda: state2)
|
||||
return outputs, state
|
||||
|
||||
|
||||
def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
|
||||
@ -305,9 +302,8 @@ def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
|
||||
Returns:
|
||||
outputs: A list of the same length as decoder_inputs of 2D Tensors with
|
||||
shape [batch_size x num_decoder_symbols] containing the generated outputs.
|
||||
states: The state of each decoder cell in each time-step. This is a list
|
||||
with length len(decoder_inputs) -- one item for each time-step.
|
||||
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
|
||||
state: The state of each decoder cell at the final time-step.
|
||||
It is a 2D Tensor of shape [batch_size x cell.state_size].
|
||||
|
||||
Raises:
|
||||
ValueError: when output_projection has the wrong shape.
|
||||
@ -344,18 +340,18 @@ def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
|
||||
return tied_rnn_seq2seq(emb_encoder_inputs, emb_decoder_inputs, cell,
|
||||
loop_function=loop_function, dtype=dtype)
|
||||
else: # If feed_previous is a Tensor, we construct 2 graphs and use cond.
|
||||
outputs1, states1 = tied_rnn_seq2seq(
|
||||
outputs1, state1 = tied_rnn_seq2seq(
|
||||
emb_encoder_inputs, emb_decoder_inputs, cell,
|
||||
loop_function=extract_argmax_and_embed, dtype=dtype)
|
||||
vs.get_variable_scope().reuse_variables()
|
||||
outputs2, states2 = tied_rnn_seq2seq(
|
||||
outputs2, state2 = tied_rnn_seq2seq(
|
||||
emb_encoder_inputs, emb_decoder_inputs, cell, dtype=dtype)
|
||||
|
||||
outputs = control_flow_ops.cond(feed_previous,
|
||||
lambda: outputs1, lambda: outputs2)
|
||||
states = control_flow_ops.cond(feed_previous,
|
||||
lambda: states1, lambda: states2)
|
||||
return outputs, states
|
||||
state = control_flow_ops.cond(feed_previous,
|
||||
lambda: state1, lambda: state2)
|
||||
return outputs, state
|
||||
|
||||
|
||||
def attention_decoder(decoder_inputs, initial_state, attention_states, cell,
|
||||
@ -397,9 +393,8 @@ def attention_decoder(decoder_inputs, initial_state, attention_states, cell,
|
||||
new_attn = softmax(V^T * tanh(W * attention_states + U * new_state))
|
||||
and then we calculate the output:
|
||||
output = linear(cell_output, new_attn).
|
||||
states: The state of each decoder cell in each time-step. This is a list
|
||||
with length len(decoder_inputs) -- one item for each time-step.
|
||||
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
|
||||
state: The state of each decoder cell at the final time-step.
|
||||
It is a 2D Tensor of shape [batch_size x cell.state_size].
|
||||
|
||||
Raises:
|
||||
ValueError: when num_heads is not positive, there are no inputs, or shapes
|
||||
@ -431,7 +426,7 @@ def attention_decoder(decoder_inputs, initial_state, attention_states, cell,
|
||||
hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
|
||||
v.append(vs.get_variable("AttnV_%d" % a, [attention_vec_size]))
|
||||
|
||||
states = [initial_state]
|
||||
state = initial_state
|
||||
|
||||
def attention(query):
|
||||
"""Put attention masks on hidden using hidden_features and query."""
|
||||
@ -471,14 +466,13 @@ def attention_decoder(decoder_inputs, initial_state, attention_states, cell,
|
||||
# Merge input and previous attentions into one vector of the right size.
|
||||
x = rnn_cell.linear([inp] + attns, cell.input_size, True)
|
||||
# Run the RNN.
|
||||
cell_output, new_state = cell(x, states[-1])
|
||||
states.append(new_state)
|
||||
cell_output, state = cell(x, state)
|
||||
# Run the attention mechanism.
|
||||
if i == 0 and initial_state_attention:
|
||||
with vs.variable_scope(vs.get_variable_scope(), reuse=True):
|
||||
attns = attention(new_state)
|
||||
attns = attention(state)
|
||||
else:
|
||||
attns = attention(new_state)
|
||||
attns = attention(state)
|
||||
|
||||
with vs.variable_scope("AttnOutputProjection"):
|
||||
output = rnn_cell.linear([cell_output] + attns, output_size, True)
|
||||
@ -487,7 +481,7 @@ def attention_decoder(decoder_inputs, initial_state, attention_states, cell,
|
||||
prev = array_ops.stop_gradient(output)
|
||||
outputs.append(output)
|
||||
|
||||
return outputs, states
|
||||
return outputs, state
|
||||
|
||||
|
||||
def embedding_attention_decoder(decoder_inputs, initial_state, attention_states,
|
||||
@ -526,9 +520,8 @@ def embedding_attention_decoder(decoder_inputs, initial_state, attention_states,
|
||||
Returns:
|
||||
outputs: A list of the same length as decoder_inputs of 2D Tensors with
|
||||
shape [batch_size x output_size] containing the generated outputs.
|
||||
states: The state of each decoder cell in each time-step. This is a list
|
||||
with length len(decoder_inputs) -- one item for each time-step.
|
||||
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
|
||||
state: The state of each decoder cell at the final time-step.
|
||||
It is a 2D Tensor of shape [batch_size x cell.state_size].
|
||||
|
||||
Raises:
|
||||
ValueError: when output_projection has the wrong shape.
|
||||
@ -607,14 +600,13 @@ def embedding_attention_seq2seq(encoder_inputs, decoder_inputs, cell,
|
||||
Returns:
|
||||
outputs: A list of the same length as decoder_inputs of 2D Tensors with
|
||||
shape [batch_size x num_decoder_symbols] containing the generated outputs.
|
||||
states: The state of each decoder cell in each time-step. This is a list
|
||||
with length len(decoder_inputs) -- one item for each time-step.
|
||||
Each item is a 2D Tensor of shape [batch_size x cell.state_size].
|
||||
state: The state of each decoder cell at the final time-step.
|
||||
It is a 2D Tensor of shape [batch_size x cell.state_size].
|
||||
"""
|
||||
with vs.variable_scope(scope or "embedding_attention_seq2seq"):
|
||||
# Encoder.
|
||||
encoder_cell = rnn_cell.EmbeddingWrapper(cell, num_encoder_symbols)
|
||||
encoder_outputs, encoder_states = rnn.rnn(
|
||||
encoder_outputs, encoder_state = rnn.rnn(
|
||||
encoder_cell, encoder_inputs, dtype=dtype)
|
||||
|
||||
# First calculate a concatenation of encoder outputs to put attention on.
|
||||
@ -630,25 +622,25 @@ def embedding_attention_seq2seq(encoder_inputs, decoder_inputs, cell,
|
||||
|
||||
if isinstance(feed_previous, bool):
|
||||
return embedding_attention_decoder(
|
||||
decoder_inputs, encoder_states[-1], attention_states, cell,
|
||||
decoder_inputs, encoder_state, attention_states, cell,
|
||||
num_decoder_symbols, num_heads, output_size, output_projection,
|
||||
feed_previous, initial_state_attention=initial_state_attention)
|
||||
else: # If feed_previous is a Tensor, we construct 2 graphs and use cond.
|
||||
outputs1, states1 = embedding_attention_decoder(
|
||||
decoder_inputs, encoder_states[-1], attention_states, cell,
|
||||
outputs1, state1 = embedding_attention_decoder(
|
||||
decoder_inputs, encoder_state, attention_states, cell,
|
||||
num_decoder_symbols, num_heads, output_size, output_projection, True,
|
||||
initial_state_attention=initial_state_attention)
|
||||
vs.get_variable_scope().reuse_variables()
|
||||
outputs2, states2 = embedding_attention_decoder(
|
||||
decoder_inputs, encoder_states[-1], attention_states, cell,
|
||||
outputs2, state2 = embedding_attention_decoder(
|
||||
decoder_inputs, encoder_state, attention_states, cell,
|
||||
num_decoder_symbols, num_heads, output_size, output_projection, False,
|
||||
initial_state_attention=initial_state_attention)
|
||||
|
||||
outputs = control_flow_ops.cond(feed_previous,
|
||||
lambda: outputs1, lambda: outputs2)
|
||||
states = control_flow_ops.cond(feed_previous,
|
||||
lambda: states1, lambda: states2)
|
||||
return outputs, states
|
||||
state = control_flow_ops.cond(feed_previous,
|
||||
lambda: state1, lambda: state2)
|
||||
return outputs, state
|
||||
|
||||
|
||||
def sequence_loss_by_example(logits, targets, weights, num_decoder_symbols,
|
||||
|
Loading…
Reference in New Issue
Block a user