Add support for arbitrarily nested tuples for RNN state.
Also fixed a bug in the RNN unit tests. Change: 123150781
This commit is contained in:
parent
3669479261
commit
ae5c66e3c2
@ -76,7 +76,7 @@ class RNNCellTest(tf.test.TestCase):
|
||||
with tf.variable_scope("other", initializer=tf.constant_initializer(0.5)):
|
||||
x = tf.zeros([1, 3]) # Test GRUCell with input_size != num_units.
|
||||
m = tf.zeros([1, 2])
|
||||
g, _ = tf.nn.rnn_cell.GRUCell(2, input_size=3)(x, m)
|
||||
g, _ = tf.nn.rnn_cell.GRUCell(2)(x, m)
|
||||
sess.run([tf.initialize_all_variables()])
|
||||
res = sess.run([g], {x.name: np.array([[1., 1., 1.]]),
|
||||
m.name: np.array([[0.1, 0.1]])})
|
||||
@ -104,7 +104,7 @@ class RNNCellTest(tf.test.TestCase):
|
||||
with tf.variable_scope("other", initializer=tf.constant_initializer(0.5)):
|
||||
x = tf.zeros([1, 3]) # Test BasicLSTMCell with input_size != num_units.
|
||||
m = tf.zeros([1, 4])
|
||||
g, out_m = tf.nn.rnn_cell.BasicLSTMCell(2, input_size=3)(x, m)
|
||||
g, out_m = tf.nn.rnn_cell.BasicLSTMCell(2)(x, m)
|
||||
sess.run([tf.initialize_all_variables()])
|
||||
res = sess.run([g, out_m], {x.name: np.array([[1., 1., 1.]]),
|
||||
m.name: 0.1 * np.ones([1, 4])})
|
||||
@ -147,8 +147,7 @@ class RNNCellTest(tf.test.TestCase):
|
||||
x = tf.zeros([batch_size, input_size])
|
||||
m = tf.zeros([batch_size, state_size])
|
||||
output, state = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units=num_units, input_size=input_size,
|
||||
num_proj=num_proj, forget_bias=1.0)(x, m)
|
||||
num_units=num_units, num_proj=num_proj, forget_bias=1.0)(x, m)
|
||||
sess.run([tf.initialize_all_variables()])
|
||||
res = sess.run([output, state],
|
||||
{x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]),
|
||||
|
@ -26,9 +26,15 @@ import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
|
||||
def _flatten(list_of_lists):
|
||||
return [x for y in list_of_lists for x in y]
|
||||
# pylint: disable=protected-access
|
||||
_is_sequence = rnn_cell._is_sequence
|
||||
_unpacked_state = rnn_cell._unpacked_state
|
||||
_packed_state = rnn_cell._packed_state
|
||||
# pylint: enable=protected-access
|
||||
|
||||
_flatten = _unpacked_state
|
||||
|
||||
|
||||
class Plus1RNNCell(tf.nn.rnn_cell.RNNCell):
|
||||
@ -48,24 +54,32 @@ class Plus1RNNCell(tf.nn.rnn_cell.RNNCell):
|
||||
|
||||
class TestStateSaver(object):
|
||||
|
||||
def __init__(self, batch_size, state_size, state_is_tuple=False):
|
||||
def __init__(self, batch_size, state_size):
|
||||
self._batch_size = batch_size
|
||||
self._state_size = state_size
|
||||
self._state_is_tuple = state_is_tuple
|
||||
self.saved_state = {}
|
||||
|
||||
def state(self, _):
|
||||
if self._state_is_tuple:
|
||||
return tuple(
|
||||
tf.zeros(tf.pack([self._batch_size, s])) for s in self._state_size)
|
||||
def state(self, name):
|
||||
if isinstance(self._state_size, dict):
|
||||
return tf.zeros([self._batch_size, self._state_size[name]])
|
||||
else:
|
||||
return tf.zeros(tf.pack([self._batch_size, self._state_size]))
|
||||
return tf.zeros([self._batch_size, self._state_size])
|
||||
|
||||
def save_state(self, name, state):
|
||||
self.saved_state[name] = state
|
||||
return tf.identity(state)
|
||||
|
||||
|
||||
class PackStateTest(tf.test.TestCase):
|
||||
|
||||
def testPackUnpackState(self):
|
||||
structure = ((3, 4), 5, (6, 7, (9, 10), 8))
|
||||
flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
|
||||
self.assertEqual(_unpacked_state(structure), (3, 4, 5, 6, 7, 9, 10, 8))
|
||||
self.assertEqual(_packed_state(structure, flat),
|
||||
(("a", "b"), "c", ("d", "e", ("f", "g"), "h")))
|
||||
|
||||
|
||||
class RNNTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -197,7 +211,7 @@ class GRUTest(tf.test.TestCase):
|
||||
concat_inputs = tf.placeholder(
|
||||
tf.float32, shape=(time_steps, batch_size, input_size))
|
||||
|
||||
cell = tf.nn.rnn_cell.GRUCell(num_units=num_units, input_size=input_size)
|
||||
cell = tf.nn.rnn_cell.GRUCell(num_units=num_units)
|
||||
|
||||
with tf.variable_scope("dynamic_scope"):
|
||||
outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn(
|
||||
@ -229,8 +243,7 @@ class LSTMTest(tf.test.TestCase):
|
||||
max_length = 8
|
||||
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
|
||||
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
|
||||
cell = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units, input_size, initializer=initializer)
|
||||
cell = tf.nn.rnn_cell.LSTMCell(num_units, initializer=initializer)
|
||||
inputs = max_length * [
|
||||
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
|
||||
outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
|
||||
@ -250,8 +263,7 @@ class LSTMTest(tf.test.TestCase):
|
||||
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
|
||||
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
|
||||
cell = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units, input_size, use_peepholes=True,
|
||||
cell_clip=0.0, initializer=initializer)
|
||||
num_units, use_peepholes=True, cell_clip=0.0, initializer=initializer)
|
||||
inputs = max_length * [
|
||||
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
|
||||
outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
|
||||
@ -276,7 +288,7 @@ class LSTMTest(tf.test.TestCase):
|
||||
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
|
||||
state_saver = TestStateSaver(batch_size, 2 * num_units)
|
||||
cell = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units, input_size, use_peepholes=False, initializer=initializer)
|
||||
num_units, use_peepholes=False, initializer=initializer)
|
||||
inputs = max_length * [
|
||||
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
|
||||
with tf.variable_scope("share_scope"):
|
||||
@ -293,16 +305,16 @@ class LSTMTest(tf.test.TestCase):
|
||||
feed_dict={inputs[0]: input_value})
|
||||
self.assertAllEqual(last_state_value, saved_state_value)
|
||||
|
||||
def _testNoProjNoShardingTupleStateSaver(self, use_gpu):
|
||||
def testNoProjNoShardingTupleStateSaver(self):
|
||||
num_units = 3
|
||||
input_size = 5
|
||||
batch_size = 2
|
||||
max_length = 8
|
||||
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
|
||||
with self.test_session(graph=tf.Graph()) as sess:
|
||||
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
|
||||
state_saver = TestStateSaver(batch_size, (num_units, num_units))
|
||||
state_saver = TestStateSaver(batch_size, num_units)
|
||||
cell = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units, input_size, use_peepholes=False, initializer=initializer,
|
||||
num_units, use_peepholes=False, initializer=initializer,
|
||||
state_is_tuple=True)
|
||||
inputs = max_length * [
|
||||
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
|
||||
@ -316,10 +328,70 @@ class LSTMTest(tf.test.TestCase):
|
||||
tf.initialize_all_variables().run()
|
||||
input_value = np.random.randn(batch_size, input_size)
|
||||
last_and_saved_states = sess.run(
|
||||
state + state_saver.saved_state.values(),
|
||||
state + (state_saver.saved_state["c"], state_saver.saved_state["m"]),
|
||||
feed_dict={inputs[0]: input_value})
|
||||
self.assertEqual(4, len(last_and_saved_states))
|
||||
self.assertEqual(last_and_saved_states[:2], last_and_saved_states[2:])
|
||||
self.assertAllEqual(last_and_saved_states[:2], last_and_saved_states[2:])
|
||||
|
||||
def testNoProjNoShardingNestedTupleStateSaver(self):
|
||||
num_units = 3
|
||||
input_size = 5
|
||||
batch_size = 2
|
||||
max_length = 8
|
||||
with self.test_session(graph=tf.Graph()) as sess:
|
||||
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
|
||||
state_saver = TestStateSaver(batch_size, {"c0": num_units,
|
||||
"m0": num_units,
|
||||
"c1": num_units + 1,
|
||||
"m1": num_units + 1,
|
||||
"c2": num_units + 2,
|
||||
"m2": num_units + 2,
|
||||
"c3": num_units + 3,
|
||||
"m3": num_units + 3})
|
||||
def _cell(i):
|
||||
return tf.nn.rnn_cell.LSTMCell(
|
||||
num_units + i, use_peepholes=False, initializer=initializer,
|
||||
state_is_tuple=True)
|
||||
|
||||
# This creates a state tuple which has 4 sub-tuples of length 2 each.
|
||||
cell = tf.nn.rnn_cell.MultiRNNCell(
|
||||
[_cell(i) for i in range(4)], state_is_tuple=True)
|
||||
|
||||
self.assertEqual(len(cell.state_size), 4)
|
||||
for i in range(4):
|
||||
self.assertEqual(len(cell.state_size[i]), 2)
|
||||
|
||||
inputs = max_length * [
|
||||
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
|
||||
|
||||
state_names = (("c0", "m0"), ("c1", "m1"),
|
||||
("c2", "m2"), ("c3", "m3"))
|
||||
with tf.variable_scope("share_scope"):
|
||||
outputs, state = tf.nn.state_saving_rnn(
|
||||
cell, inputs, state_saver=state_saver, state_name=state_names)
|
||||
self.assertEqual(len(outputs), len(inputs))
|
||||
|
||||
# Final output comes from _cell(3) which has state size num_units + 3
|
||||
for out in outputs:
|
||||
self.assertEqual(out.get_shape().as_list(), [batch_size, num_units + 3])
|
||||
|
||||
tf.initialize_all_variables().run()
|
||||
input_value = np.random.randn(batch_size, input_size)
|
||||
last_states = sess.run(
|
||||
list(_unpacked_state(state)), feed_dict={inputs[0]: input_value})
|
||||
saved_states = sess.run(
|
||||
list(state_saver.saved_state.values()),
|
||||
feed_dict={inputs[0]: input_value})
|
||||
self.assertEqual(8, len(last_states))
|
||||
self.assertEqual(8, len(saved_states))
|
||||
flat_state_names = _unpacked_state(state_names)
|
||||
named_saved_states = dict(
|
||||
zip(state_saver.saved_state.keys(), saved_states))
|
||||
|
||||
for i in range(8):
|
||||
self.assertAllEqual(
|
||||
last_states[i],
|
||||
named_saved_states[flat_state_names[i]])
|
||||
|
||||
def _testProjNoSharding(self, use_gpu):
|
||||
num_units = 3
|
||||
@ -332,7 +404,7 @@ class LSTMTest(tf.test.TestCase):
|
||||
inputs = max_length * [
|
||||
tf.placeholder(tf.float32, shape=(None, input_size))]
|
||||
cell = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units, input_size, use_peepholes=True,
|
||||
num_units, use_peepholes=True,
|
||||
num_proj=num_proj, initializer=initializer)
|
||||
outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
|
||||
self.assertEqual(len(outputs), len(inputs))
|
||||
@ -353,21 +425,21 @@ class LSTMTest(tf.test.TestCase):
|
||||
inputs = max_length * [
|
||||
tf.placeholder(tf.float32, shape=(None, input_size))]
|
||||
cell_notuple = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units, input_size, use_peepholes=True,
|
||||
num_units, use_peepholes=True,
|
||||
num_proj=num_proj, initializer=initializer)
|
||||
cell_tuple = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units, input_size, use_peepholes=True,
|
||||
num_units, use_peepholes=True,
|
||||
num_proj=num_proj, initializer=initializer, state_is_tuple=True)
|
||||
outputs_notuple, state_notuple = tf.nn.rnn(
|
||||
cell_notuple, inputs, dtype=tf.float32,
|
||||
sequence_length=sequence_length)
|
||||
tf.get_variable_scope().reuse_variables()
|
||||
outputs_tuple, state_is_tuple = tf.nn.rnn(
|
||||
outputs_tuple, state_tuple = tf.nn.rnn(
|
||||
cell_tuple, inputs, dtype=tf.float32,
|
||||
sequence_length=sequence_length)
|
||||
self.assertEqual(len(outputs_notuple), len(inputs))
|
||||
self.assertEqual(len(outputs_tuple), len(inputs))
|
||||
self.assertTrue(isinstance(state_is_tuple, tuple))
|
||||
self.assertTrue(isinstance(state_tuple, tuple))
|
||||
self.assertTrue(isinstance(state_notuple, tf.Tensor))
|
||||
|
||||
tf.initialize_all_variables().run()
|
||||
@ -380,9 +452,9 @@ class LSTMTest(tf.test.TestCase):
|
||||
|
||||
(state_notuple_v,) = sess.run(
|
||||
(state_notuple,), feed_dict={inputs[0]: input_value})
|
||||
state_is_tuple_v = sess.run(
|
||||
state_is_tuple, feed_dict={inputs[0]: input_value})
|
||||
self.assertAllEqual(state_notuple_v, np.hstack(state_is_tuple_v))
|
||||
state_tuple_v = sess.run(
|
||||
state_tuple, feed_dict={inputs[0]: input_value})
|
||||
self.assertAllEqual(state_notuple_v, np.hstack(state_tuple_v))
|
||||
|
||||
def _testProjSharding(self, use_gpu):
|
||||
num_units = 3
|
||||
@ -400,7 +472,6 @@ class LSTMTest(tf.test.TestCase):
|
||||
|
||||
cell = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
input_size=input_size,
|
||||
use_peepholes=True,
|
||||
num_proj=num_proj,
|
||||
num_unit_shards=num_unit_shards,
|
||||
@ -430,7 +501,6 @@ class LSTMTest(tf.test.TestCase):
|
||||
|
||||
cell = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
input_size=input_size,
|
||||
use_peepholes=True,
|
||||
num_proj=num_proj,
|
||||
num_unit_shards=num_unit_shards,
|
||||
@ -455,7 +525,6 @@ class LSTMTest(tf.test.TestCase):
|
||||
|
||||
cell = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
input_size=input_size,
|
||||
use_peepholes=True,
|
||||
num_proj=num_proj,
|
||||
num_unit_shards=num_unit_shards,
|
||||
@ -487,7 +556,7 @@ class LSTMTest(tf.test.TestCase):
|
||||
initializer = tf.constant_initializer(0.001)
|
||||
|
||||
cell_noshard = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units, input_size,
|
||||
num_units,
|
||||
num_proj=num_proj,
|
||||
use_peepholes=True,
|
||||
initializer=initializer,
|
||||
@ -495,7 +564,7 @@ class LSTMTest(tf.test.TestCase):
|
||||
num_proj_shards=num_proj_shards)
|
||||
|
||||
cell_shard = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units, input_size, use_peepholes=True,
|
||||
num_units, use_peepholes=True,
|
||||
initializer=initializer, num_proj=num_proj)
|
||||
|
||||
with tf.variable_scope("noshard_scope"):
|
||||
@ -541,7 +610,6 @@ class LSTMTest(tf.test.TestCase):
|
||||
|
||||
cell = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
input_size=input_size,
|
||||
use_peepholes=True,
|
||||
num_proj=num_proj,
|
||||
num_unit_shards=num_unit_shards,
|
||||
@ -577,10 +645,10 @@ class LSTMTest(tf.test.TestCase):
|
||||
inputs = max_length * [
|
||||
tf.placeholder(tf.float32, shape=(None, input_size))]
|
||||
cell = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units, input_size, use_peepholes=True,
|
||||
num_units, use_peepholes=True,
|
||||
num_proj=num_proj, initializer=initializer)
|
||||
cell_d = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units, input_size, use_peepholes=True,
|
||||
num_units, use_peepholes=True,
|
||||
num_proj=num_proj, initializer=initializer_d)
|
||||
|
||||
with tf.variable_scope("share_scope"):
|
||||
@ -616,7 +684,7 @@ class LSTMTest(tf.test.TestCase):
|
||||
inputs = max_length * [
|
||||
tf.placeholder(tf.float32, shape=(None, input_size))]
|
||||
cell = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units, input_size, use_peepholes=True,
|
||||
num_units, use_peepholes=True,
|
||||
num_proj=num_proj, initializer=initializer)
|
||||
|
||||
with tf.name_scope("scope0"):
|
||||
@ -649,7 +717,7 @@ class LSTMTest(tf.test.TestCase):
|
||||
tf.placeholder(tf.float32, shape=(None, input_size))]
|
||||
inputs_c = tf.pack(inputs)
|
||||
cell = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units, input_size, use_peepholes=True,
|
||||
num_units, use_peepholes=True,
|
||||
num_proj=num_proj, initializer=initializer, state_is_tuple=True)
|
||||
outputs_static, state_static = tf.nn.rnn(
|
||||
cell, inputs, dtype=tf.float32,
|
||||
@ -675,6 +743,61 @@ class LSTMTest(tf.test.TestCase):
|
||||
self.assertAllEqual(
|
||||
np.hstack(state_static_v), np.hstack(state_dynamic_v))
|
||||
|
||||
def testDynamicRNNWithNestedTupleStates(self):
|
||||
num_units = 3
|
||||
input_size = 5
|
||||
batch_size = 2
|
||||
num_proj = 4
|
||||
max_length = 8
|
||||
sequence_length = [4, 6]
|
||||
with self.test_session(graph=tf.Graph()) as sess:
|
||||
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
|
||||
inputs = max_length * [
|
||||
tf.placeholder(tf.float32, shape=(None, input_size))]
|
||||
inputs_c = tf.pack(inputs)
|
||||
def _cell(i):
|
||||
return tf.nn.rnn_cell.LSTMCell(
|
||||
num_units + i, use_peepholes=True,
|
||||
num_proj=num_proj + i, initializer=initializer, state_is_tuple=True)
|
||||
|
||||
# This creates a state tuple which has 4 sub-tuples of length 2 each.
|
||||
cell = tf.nn.rnn_cell.MultiRNNCell(
|
||||
[_cell(i) for i in range(4)], state_is_tuple=True)
|
||||
|
||||
self.assertEqual(len(cell.state_size), 4)
|
||||
for i in range(4):
|
||||
self.assertEqual(len(cell.state_size[i]), 2)
|
||||
|
||||
test_zero = cell.zero_state(1, tf.float32)
|
||||
self.assertEqual(len(test_zero), 4)
|
||||
for i in range(4):
|
||||
self.assertEqual(test_zero[i][0].get_shape()[1], cell.state_size[i][0])
|
||||
self.assertEqual(test_zero[i][1].get_shape()[1], cell.state_size[i][1])
|
||||
|
||||
outputs_static, state_static = tf.nn.rnn(
|
||||
cell, inputs, dtype=tf.float32,
|
||||
sequence_length=sequence_length)
|
||||
tf.get_variable_scope().reuse_variables()
|
||||
outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn(
|
||||
cell, inputs_c, dtype=tf.float32, time_major=True,
|
||||
sequence_length=sequence_length)
|
||||
|
||||
tf.initialize_all_variables().run()
|
||||
|
||||
input_value = np.random.randn(batch_size, input_size)
|
||||
outputs_static_v = sess.run(
|
||||
outputs_static, feed_dict={inputs[0]: input_value})
|
||||
outputs_dynamic_v = sess.run(
|
||||
outputs_dynamic, feed_dict={inputs[0]: input_value})
|
||||
self.assertAllEqual(outputs_static_v, outputs_dynamic_v)
|
||||
|
||||
state_static_v = sess.run(
|
||||
_unpacked_state(state_static), feed_dict={inputs[0]: input_value})
|
||||
state_dynamic_v = sess.run(
|
||||
_unpacked_state(state_dynamic), feed_dict={inputs[0]: input_value})
|
||||
self.assertAllEqual(
|
||||
np.hstack(state_static_v), np.hstack(state_dynamic_v))
|
||||
|
||||
def _testDynamicEquivalentToStaticRNN(self, use_gpu, use_sequence_length):
|
||||
time_steps = 8
|
||||
num_units = 3
|
||||
@ -697,7 +820,7 @@ class LSTMTest(tf.test.TestCase):
|
||||
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
|
||||
|
||||
cell = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units, input_size, use_peepholes=True,
|
||||
num_units, use_peepholes=True,
|
||||
initializer=initializer, num_proj=num_proj)
|
||||
|
||||
with tf.variable_scope("dynamic_scope"):
|
||||
@ -752,7 +875,7 @@ class LSTMTest(tf.test.TestCase):
|
||||
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
|
||||
|
||||
cell = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units, input_size, use_peepholes=True,
|
||||
num_units, use_peepholes=True,
|
||||
initializer=initializer, num_proj=num_proj)
|
||||
|
||||
with tf.variable_scope("dynamic_scope"):
|
||||
@ -1010,8 +1133,7 @@ def _static_vs_dynamic_rnn_benchmark_static(inputs_list_t, sequence_length):
|
||||
(_, input_size) = inputs_list_t[0].get_shape().as_list()
|
||||
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
|
||||
cell = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units=input_size, input_size=input_size, use_peepholes=True,
|
||||
initializer=initializer)
|
||||
num_units=input_size, use_peepholes=True, initializer=initializer)
|
||||
outputs, final_state = tf.nn.rnn(
|
||||
cell, inputs_list_t, sequence_length=sequence_length, dtype=tf.float32)
|
||||
|
||||
@ -1025,8 +1147,7 @@ def _static_vs_dynamic_rnn_benchmark_dynamic(inputs_t, sequence_length):
|
||||
(unused_0, unused_1, input_size) = inputs_t.get_shape().as_list()
|
||||
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
|
||||
cell = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units=input_size, input_size=input_size, use_peepholes=True,
|
||||
initializer=initializer)
|
||||
num_units=input_size, use_peepholes=True, initializer=initializer)
|
||||
outputs, final_state = tf.nn.dynamic_rnn(
|
||||
cell, inputs_t, sequence_length=sequence_length, dtype=tf.float32)
|
||||
|
||||
@ -1129,8 +1250,7 @@ def _half_seq_len_vs_unroll_half_rnn_benchmark(inputs_list_t, sequence_length):
|
||||
(_, input_size) = inputs_list_t[0].get_shape().as_list()
|
||||
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
|
||||
cell = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units=input_size, input_size=input_size, use_peepholes=True,
|
||||
initializer=initializer)
|
||||
num_units=input_size, use_peepholes=True, initializer=initializer)
|
||||
outputs, final_state = tf.nn.rnn(
|
||||
cell, inputs_list_t, sequence_length=sequence_length, dtype=tf.float32)
|
||||
|
||||
@ -1183,7 +1303,7 @@ def _concat_state_vs_tuple_state_rnn_benchmark(
|
||||
(_, input_size) = inputs_list_t[0].get_shape().as_list()
|
||||
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
|
||||
cell = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units=input_size, input_size=input_size, use_peepholes=True,
|
||||
num_units=input_size, use_peepholes=True,
|
||||
initializer=initializer, state_is_tuple=state_is_tuple)
|
||||
outputs, final_state = tf.nn.rnn(
|
||||
cell, inputs_list_t, sequence_length=sequence_length, dtype=tf.float32)
|
||||
@ -1239,8 +1359,7 @@ def _dynamic_rnn_swap_memory_benchmark(inputs_t, sequence_length,
|
||||
(unused_0, unused_1, input_size) = inputs_t.get_shape().as_list()
|
||||
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
|
||||
cell = tf.nn.rnn_cell.LSTMCell(
|
||||
num_units=input_size, input_size=input_size, use_peepholes=True,
|
||||
initializer=initializer)
|
||||
num_units=input_size, use_peepholes=True, initializer=initializer)
|
||||
outputs, final_state = tf.nn.dynamic_rnn(
|
||||
cell, inputs_t, sequence_length=sequence_length,
|
||||
swap_memory=swap_memory, dtype=tf.float32)
|
||||
|
@ -32,6 +32,13 @@ from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
_is_sequence = rnn_cell._is_sequence
|
||||
_unpacked_state = rnn_cell._unpacked_state
|
||||
_packed_state = rnn_cell._packed_state
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
def rnn(cell, inputs, initial_state=None, dtype=None,
|
||||
sequence_length=None, scope=None):
|
||||
"""Creates a recurrent neural network specified by RNNCell `cell`.
|
||||
@ -177,20 +184,26 @@ def state_saving_rnn(cell, inputs, state_saver, state_name,
|
||||
type of `state_name` does not match that of `cell.state_size`.
|
||||
"""
|
||||
state_size = cell.state_size
|
||||
state_is_tuple = isinstance(state_size, (list, tuple))
|
||||
state_name_tuple = isinstance(state_name, (list, tuple))
|
||||
state_is_tuple = _is_sequence(state_size)
|
||||
state_name_tuple = _is_sequence(state_name)
|
||||
|
||||
if state_is_tuple != state_name_tuple:
|
||||
raise ValueError(
|
||||
"state_name should be a tuple iff cell.state_size is. state_name: %s, "
|
||||
"cell.state_size: %s" % (str(state_name), str(state_size)))
|
||||
"state_name should be the same type as cell.state_size. "
|
||||
"state_name: %s, cell.state_size: %s"
|
||||
% (str(state_name), str(state_size)))
|
||||
|
||||
if state_is_tuple:
|
||||
if len(state_name) != len(state_size):
|
||||
raise ValueError("len(state_name) != len(state_size): %d vs. %d"
|
||||
% (len(state_name), len(state_size)))
|
||||
state_name_flat = _unpacked_state(state_name)
|
||||
state_size_flat = _unpacked_state(state_size)
|
||||
|
||||
initial_state = tuple(state_saver.state(n) for n in state_name)
|
||||
if len(state_name_flat) != len(state_size_flat):
|
||||
raise ValueError("#elems(state_name) != #elems(state_size): %d vs. %d"
|
||||
% (len(state_name_flat), len(state_size_flat)))
|
||||
|
||||
initial_state = _packed_state(
|
||||
structure=state_name,
|
||||
state=[state_saver.state(n) for n in state_name_flat])
|
||||
else:
|
||||
initial_state = state_saver.state(state_name)
|
||||
|
||||
@ -198,8 +211,10 @@ def state_saving_rnn(cell, inputs, state_saver, state_name,
|
||||
sequence_length=sequence_length, scope=scope)
|
||||
|
||||
if state_is_tuple:
|
||||
state_flat = _unpacked_state(state)
|
||||
save_state = [
|
||||
state_saver.save_state(n, s) for (n, s) in zip(state_name, state)]
|
||||
state_saver.save_state(n, s)
|
||||
for (n, s) in zip(state_name_flat, state_flat)]
|
||||
else:
|
||||
save_state = [state_saver.save_state(state_name, state)]
|
||||
|
||||
@ -262,9 +277,10 @@ def _rnn_step(
|
||||
that returned by `state_size`.
|
||||
"""
|
||||
|
||||
state_is_tuple = isinstance(state, (list, tuple))
|
||||
state_is_tuple = _is_sequence(state)
|
||||
orig_state = state
|
||||
# Convert state to a list for ease of use
|
||||
state = list(state) if state_is_tuple else [state]
|
||||
state = list(_unpacked_state(state)) if state_is_tuple else [state]
|
||||
state_shape = [s.get_shape() for s in state]
|
||||
|
||||
def _copy_some_through(new_output, new_state):
|
||||
@ -279,7 +295,8 @@ def _rnn_step(
|
||||
def _maybe_copy_some_through():
|
||||
"""Run RNN step. Pass through either no or some past state."""
|
||||
new_output, new_state = call_cell()
|
||||
new_state = list(new_state) if state_is_tuple else [new_state]
|
||||
new_state = (
|
||||
list(_unpacked_state(new_state)) if state_is_tuple else [new_state])
|
||||
|
||||
if len(state) != len(new_state):
|
||||
raise ValueError(
|
||||
@ -300,7 +317,8 @@ def _rnn_step(
|
||||
# steps. This is faster when max_seq_len is equal to the number of unrolls
|
||||
# (which is typical for dynamic_rnn).
|
||||
new_output, new_state = call_cell()
|
||||
new_state = list(new_state) if state_is_tuple else [new_state]
|
||||
new_state = (
|
||||
list(_unpacked_state(new_state)) if state_is_tuple else [new_state])
|
||||
|
||||
if len(state) != len(new_state):
|
||||
raise ValueError(
|
||||
@ -325,7 +343,9 @@ def _rnn_step(
|
||||
final_state_i.set_shape(state_shape_i)
|
||||
|
||||
if state_is_tuple:
|
||||
return (final_output, tuple(final_state))
|
||||
return (
|
||||
final_output,
|
||||
_packed_state(structure=orig_state, state=final_state))
|
||||
else:
|
||||
return (final_output, final_state[0])
|
||||
|
||||
@ -613,9 +633,9 @@ def _dynamic_rnn_loop(
|
||||
time = array_ops.constant(0, dtype=dtypes.int32, name="time")
|
||||
|
||||
state_size = cell.state_size
|
||||
state_is_tuple = isinstance(state_size, (list, tuple))
|
||||
state_is_tuple = _is_sequence(state_size)
|
||||
|
||||
state = tuple(state) if state_is_tuple else (state,)
|
||||
state = _unpacked_state(state) if state_is_tuple else (state,)
|
||||
|
||||
with ops.op_scope([], "dynamic_rnn") as scope:
|
||||
base_name = scope
|
||||
@ -646,8 +666,9 @@ def _dynamic_rnn_loop(
|
||||
# Restore some shape information
|
||||
input_t.set_shape([const_batch_size, const_depth])
|
||||
|
||||
# Unpack state if not using state tuples
|
||||
state = tuple(state) if state_is_tuple else state[0]
|
||||
# Pack state back up for use by cell
|
||||
state = (_packed_state(structure=state_size, state=state)
|
||||
if state_is_tuple else state[0])
|
||||
|
||||
call_cell = lambda: cell(input_t, state)
|
||||
|
||||
@ -665,7 +686,7 @@ def _dynamic_rnn_loop(
|
||||
(output, new_state) = call_cell()
|
||||
|
||||
# Pack state if using state tuples
|
||||
new_state = tuple(new_state) if state_is_tuple else (new_state,)
|
||||
new_state = _unpacked_state(new_state) if state_is_tuple else (new_state,)
|
||||
|
||||
output_ta_t = output_ta_t.write(time, output)
|
||||
|
||||
@ -686,6 +707,7 @@ def _dynamic_rnn_loop(
|
||||
const_time_steps, const_batch_size, cell.output_size])
|
||||
|
||||
# Unpack final state if not using state tuples.
|
||||
final_state = tuple(final_state) if state_is_tuple else final_state[0]
|
||||
final_state = (
|
||||
_unpacked_state(final_state) if state_is_tuple else final_state[0])
|
||||
|
||||
return (final_outputs, final_state)
|
||||
|
@ -18,11 +18,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import math
|
||||
|
||||
# pylint: disable=redefined-builtin,unused-import
|
||||
from six.moves import xrange
|
||||
# pylint: enable=redefined-builtin,unused-import
|
||||
import six
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -39,6 +38,88 @@ from tensorflow.python.ops.math_ops import tanh
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
|
||||
def _is_sequence(seq):
|
||||
return (isinstance(seq, collections.Sequence)
|
||||
and not isinstance(seq, six.string_types))
|
||||
|
||||
|
||||
def _packed_state_with_indices(structure, flat, index):
|
||||
"""Helper function for _packed_state.
|
||||
|
||||
Args:
|
||||
structure: Substructure (tuple of elements and/or tuples) to mimic
|
||||
flat: Flattened values to output substructure for.
|
||||
index: Index at which to start reading from flat.
|
||||
|
||||
Returns:
|
||||
The tuple (new_index, child), where:
|
||||
* new_index - the updated index into `flat` having processed `structure`.
|
||||
* packed - the subset of `flat` corresponding to `structure`,
|
||||
having started at `index`, and packed into the same nested
|
||||
format.
|
||||
|
||||
Raises:
|
||||
ValueError: if `structure` contains more elements than `flat`
|
||||
(assuming indexing starts from `index`).
|
||||
"""
|
||||
packed = []
|
||||
for s in structure:
|
||||
if _is_sequence(s):
|
||||
new_index, child = _packed_state_with_indices(s, flat, index)
|
||||
packed.append(type(s)(child))
|
||||
index = new_index
|
||||
else:
|
||||
packed.append(flat[index])
|
||||
index += 1
|
||||
return (index, packed)
|
||||
|
||||
|
||||
def _yield_unpacked_state(state):
|
||||
for s in state:
|
||||
if _is_sequence(s):
|
||||
for si in _yield_unpacked_state(s):
|
||||
yield si
|
||||
else:
|
||||
yield s
|
||||
|
||||
|
||||
def _unpacked_state(state):
|
||||
if not _is_sequence(state):
|
||||
raise TypeError("state must be a sequence")
|
||||
return type(state)(_yield_unpacked_state(state))
|
||||
|
||||
|
||||
def _packed_state(structure, state):
|
||||
"""Returns the flat state packed into a recursive tuple like structure.
|
||||
|
||||
Args:
|
||||
structure: tuple or list constructed of scalars and/or other tuples/lists.
|
||||
state: flattened state.
|
||||
|
||||
Returns:
|
||||
packed: `state` converted to have the same recursive structure as
|
||||
`structure`.
|
||||
|
||||
Raises:
|
||||
TypeError: If structure or state is not a tuple or list.
|
||||
ValueError: If state and structure have different element counts.
|
||||
"""
|
||||
if not _is_sequence(structure):
|
||||
raise TypeError("structure must be a sequence")
|
||||
if not _is_sequence(state):
|
||||
raise TypeError("state must be a sequence")
|
||||
|
||||
flat_structure = _unpacked_state(structure)
|
||||
if len(flat_structure) != len(state):
|
||||
raise ValueError(
|
||||
"Internal error: Could not pack state. Structure had %d elements, but "
|
||||
"state had %d elements. Structure: %s, state: %s."
|
||||
% (len(flat_structure), len(state), structure, state))
|
||||
|
||||
(_, packed) = _packed_state_with_indices(structure, state, 0)
|
||||
return type(structure)(packed)
|
||||
|
||||
|
||||
class RNNCell(object):
|
||||
"""Abstract object representing an RNN cell.
|
||||
|
||||
@ -98,17 +179,19 @@ class RNNCell(object):
|
||||
If `state_size` is an int, then the return value is a `2-D` tensor of
|
||||
shape `[batch_size x state_size]` filled with zeros.
|
||||
|
||||
If `state_size` is a list or tuple of ints, then the return value is
|
||||
a tuple of `2-D` tensors with shape
|
||||
`[batch_size x s] for s in state_size`.
|
||||
If `state_size` is a nested list or tuple, then the return value is
|
||||
a nested list or tuple (of the same structure) of `2-D` tensors with
|
||||
the shapes `[batch_size x s]` for each s in `state_size`.
|
||||
"""
|
||||
state_size = self.state_size
|
||||
if isinstance(state_size, (list, tuple)):
|
||||
zeros = tuple(
|
||||
if _is_sequence(state_size):
|
||||
state_size_flat = _unpacked_state(state_size)
|
||||
zeros_flat = [
|
||||
array_ops.zeros(array_ops.pack([batch_size, s]), dtype=dtype)
|
||||
for s in state_size)
|
||||
for s, z in zip(state_size, zeros):
|
||||
for s in state_size_flat]
|
||||
for s, z in zip(state_size_flat, zeros_flat):
|
||||
z.set_shape([None, s])
|
||||
zeros = _packed_state(structure=state_size, state=zeros_flat)
|
||||
else:
|
||||
zeros = array_ops.zeros(
|
||||
array_ops.pack([batch_size, state_size]), dtype=dtype)
|
||||
@ -675,7 +758,7 @@ class MultiRNNCell(RNNCell):
|
||||
self._cells = cells
|
||||
self._state_is_tuple = state_is_tuple
|
||||
if not state_is_tuple:
|
||||
if any(isinstance(c.state_size, (list, tuple)) for c in self._cells):
|
||||
if any(_is_sequence(c.state_size) for c in self._cells):
|
||||
raise ValueError("Some cells return tuples of states, but the flag "
|
||||
"state_is_tuple is not set. State sizes are: %s"
|
||||
% str([c.state_size for c in self._cells]))
|
||||
@ -700,7 +783,7 @@ class MultiRNNCell(RNNCell):
|
||||
for i, cell in enumerate(self._cells):
|
||||
with vs.variable_scope("Cell%d" % i):
|
||||
if self._state_is_tuple:
|
||||
if not isinstance(state, (list, tuple)):
|
||||
if not _is_sequence(state):
|
||||
raise ValueError(
|
||||
"Expected state to be a tuple of length %d, but received: %s"
|
||||
% (len(self.state_size), state))
|
||||
@ -778,9 +861,9 @@ def _linear(args, output_size, bias, bias_start=0.0, scope=None):
|
||||
Raises:
|
||||
ValueError: if some of the arguments has unspecified or wrong shape.
|
||||
"""
|
||||
if args is None or (isinstance(args, (list, tuple)) and not args):
|
||||
if args is None or (_is_sequence(args) and not args):
|
||||
raise ValueError("`args` must be specified")
|
||||
if not isinstance(args, (list, tuple)):
|
||||
if not _is_sequence(args):
|
||||
args = [args]
|
||||
|
||||
# Calculate the total size of arguments on dimension 1.
|
||||
|
Loading…
Reference in New Issue
Block a user