diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py
index e3756e03d25..10a1a2e2a39 100644
--- a/tensorflow/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/python/kernel_tests/rnn_cell_test.py
@@ -76,7 +76,7 @@ class RNNCellTest(tf.test.TestCase):
       with tf.variable_scope("other", initializer=tf.constant_initializer(0.5)):
         x = tf.zeros([1, 3])  # Test GRUCell with input_size != num_units.
         m = tf.zeros([1, 2])
-        g, _ = tf.nn.rnn_cell.GRUCell(2, input_size=3)(x, m)
+        g, _ = tf.nn.rnn_cell.GRUCell(2)(x, m)
         sess.run([tf.initialize_all_variables()])
         res = sess.run([g], {x.name: np.array([[1., 1., 1.]]),
                              m.name: np.array([[0.1, 0.1]])})
@@ -104,7 +104,7 @@ class RNNCellTest(tf.test.TestCase):
       with tf.variable_scope("other", initializer=tf.constant_initializer(0.5)):
         x = tf.zeros([1, 3])  # Test BasicLSTMCell with input_size != num_units.
         m = tf.zeros([1, 4])
-        g, out_m = tf.nn.rnn_cell.BasicLSTMCell(2, input_size=3)(x, m)
+        g, out_m = tf.nn.rnn_cell.BasicLSTMCell(2)(x, m)
         sess.run([tf.initialize_all_variables()])
         res = sess.run([g, out_m], {x.name: np.array([[1., 1., 1.]]),
                                     m.name: 0.1 * np.ones([1, 4])})
@@ -147,8 +147,7 @@ class RNNCellTest(tf.test.TestCase):
       x = tf.zeros([batch_size, input_size])
       m = tf.zeros([batch_size, state_size])
       output, state = tf.nn.rnn_cell.LSTMCell(
-          num_units=num_units, input_size=input_size,
-          num_proj=num_proj, forget_bias=1.0)(x, m)
+          num_units=num_units, num_proj=num_proj, forget_bias=1.0)(x, m)
       sess.run([tf.initialize_all_variables()])
       res = sess.run([output, state],
                      {x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]),
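
The rnn_cell_test.py changes above all follow one pattern: the deprecated
`input_size` constructor argument is dropped, since a cell now infers its
input dimension from the shape of the tensor it is called with. A minimal
sketch of the new-style usage (the shapes here are illustrative, not taken
from any one test):

    import tensorflow as tf

    cell = tf.nn.rnn_cell.GRUCell(2)  # no input_size argument anymore
    x = tf.zeros([1, 3])              # input size (3) is read from x's shape
    m = tf.zeros([1, 2])              # state size must still match num_units
    output, new_state = cell(x, m)    # weights are built at first call
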
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index 646c981791d..469635ae4f8 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -26,9 +26,15 @@ import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
 
+from tensorflow.python.ops import rnn_cell
 
-def _flatten(list_of_lists):
-  return [x for y in list_of_lists for x in y]
+# pylint: disable=protected-access
+_is_sequence = rnn_cell._is_sequence
+_unpacked_state = rnn_cell._unpacked_state
+_packed_state = rnn_cell._packed_state
+# pylint: enable=protected-access
+
+_flatten = _unpacked_state
 
 
 class Plus1RNNCell(tf.nn.rnn_cell.RNNCell):
@@ -48,24 +54,32 @@ class Plus1RNNCell(tf.nn.rnn_cell.RNNCell):
 
 class TestStateSaver(object):
 
-  def __init__(self, batch_size, state_size, state_is_tuple=False):
+  def __init__(self, batch_size, state_size):
     self._batch_size = batch_size
     self._state_size = state_size
-    self._state_is_tuple = state_is_tuple
     self.saved_state = {}
 
-  def state(self, _):
-    if self._state_is_tuple:
-      return tuple(
-          tf.zeros(tf.pack([self._batch_size, s])) for s in self._state_size)
+  def state(self, name):
+    if isinstance(self._state_size, dict):
+      return tf.zeros([self._batch_size, self._state_size[name]])
     else:
-      return tf.zeros(tf.pack([self._batch_size, self._state_size]))
+      return tf.zeros([self._batch_size, self._state_size])
 
   def save_state(self, name, state):
     self.saved_state[name] = state
     return tf.identity(state)
 
 
+class PackStateTest(tf.test.TestCase):
+
+  def testPackUnpackState(self):
+    structure = ((3, 4), 5, (6, 7, (9, 10), 8))
+    flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
+    self.assertEqual(_unpacked_state(structure), (3, 4, 5, 6, 7, 9, 10, 8))
+    self.assertEqual(_packed_state(structure, flat),
+                     (("a", "b"), "c", ("d", "e", ("f", "g"), "h")))
+
+
 class RNNTest(tf.test.TestCase):
 
   def setUp(self):
@@ -197,7 +211,7 @@ class GRUTest(tf.test.TestCase):
       concat_inputs = tf.placeholder(
          tf.float32, shape=(time_steps, batch_size, input_size))
 
-      cell = tf.nn.rnn_cell.GRUCell(num_units=num_units, input_size=input_size)
+      cell = tf.nn.rnn_cell.GRUCell(num_units=num_units)
 
       with tf.variable_scope("dynamic_scope"):
         outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn(
@@ -229,8 +243,7 @@ class LSTMTest(tf.test.TestCase):
     max_length = 8
     with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
       initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
-      cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, initializer=initializer)
+      cell = tf.nn.rnn_cell.LSTMCell(num_units, initializer=initializer)
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(batch_size, input_size))]
       outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
@@ -250,8 +263,7 @@
     with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
       initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
-          cell_clip=0.0, initializer=initializer)
+          num_units, use_peepholes=True, cell_clip=0.0, initializer=initializer)
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(batch_size, input_size))]
       outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
@@ -276,7 +288,7 @@
       initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
       state_saver = TestStateSaver(batch_size, 2 * num_units)
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=False, initializer=initializer)
+          num_units, use_peepholes=False, initializer=initializer)
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(batch_size, input_size))]
       with tf.variable_scope("share_scope"):
@@ -293,16 +305,16 @@
           feed_dict={inputs[0]: input_value})
       self.assertAllEqual(last_state_value, saved_state_value)
 
-  def _testNoProjNoShardingTupleStateSaver(self, use_gpu):
+  def testNoProjNoShardingTupleStateSaver(self):
     num_units = 3
     input_size = 5
     batch_size = 2
     max_length = 8
-    with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
+    with self.test_session(graph=tf.Graph()) as sess:
       initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
-      state_saver = TestStateSaver(batch_size, (num_units, num_units))
+      state_saver = TestStateSaver(batch_size, num_units)
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=False, initializer=initializer,
+          num_units, use_peepholes=False, initializer=initializer,
           state_is_tuple=True)
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(batch_size, input_size))]
@@ -316,10 +328,70 @@
       tf.initialize_all_variables().run()
       input_value = np.random.randn(batch_size, input_size)
       last_and_saved_states = sess.run(
-          state + state_saver.saved_state.values(),
+          state + (state_saver.saved_state["c"], state_saver.saved_state["m"]),
          feed_dict={inputs[0]: input_value})
       self.assertEqual(4, len(last_and_saved_states))
-      self.assertEqual(last_and_saved_states[:2], last_and_saved_states[2:])
+      self.assertAllEqual(last_and_saved_states[:2], last_and_saved_states[2:])
+
+  def testNoProjNoShardingNestedTupleStateSaver(self):
+    num_units = 3
+    input_size = 5
+    batch_size = 2
+    max_length = 8
+    with self.test_session(graph=tf.Graph()) as sess:
+      initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
+      state_saver = TestStateSaver(batch_size, {"c0": num_units,
+                                                "m0": num_units,
+                                                "c1": num_units + 1,
+                                                "m1": num_units + 1,
+                                                "c2": num_units + 2,
+                                                "m2": num_units + 2,
+                                                "c3": num_units + 3,
+                                                "m3": num_units + 3})
+      def _cell(i):
+        return tf.nn.rnn_cell.LSTMCell(
+            num_units + i, use_peepholes=False, initializer=initializer,
+            state_is_tuple=True)
+
+      # This creates a state tuple which has 4 sub-tuples of length 2 each.
+      cell = tf.nn.rnn_cell.MultiRNNCell(
+          [_cell(i) for i in range(4)], state_is_tuple=True)
+
+      self.assertEqual(len(cell.state_size), 4)
+      for i in range(4):
+        self.assertEqual(len(cell.state_size[i]), 2)
+
+      inputs = max_length * [
+          tf.placeholder(tf.float32, shape=(batch_size, input_size))]
+
+      state_names = (("c0", "m0"), ("c1", "m1"),
+                     ("c2", "m2"), ("c3", "m3"))
+      with tf.variable_scope("share_scope"):
+        outputs, state = tf.nn.state_saving_rnn(
+            cell, inputs, state_saver=state_saver, state_name=state_names)
+      self.assertEqual(len(outputs), len(inputs))
+
+      # Final output comes from _cell(3) which has state size num_units + 3
+      for out in outputs:
+        self.assertEqual(out.get_shape().as_list(), [batch_size, num_units + 3])
+
+      tf.initialize_all_variables().run()
+      input_value = np.random.randn(batch_size, input_size)
+      last_states = sess.run(
+          list(_unpacked_state(state)), feed_dict={inputs[0]: input_value})
+      saved_states = sess.run(
+          list(state_saver.saved_state.values()),
+          feed_dict={inputs[0]: input_value})
+      self.assertEqual(8, len(last_states))
+      self.assertEqual(8, len(saved_states))
+      flat_state_names = _unpacked_state(state_names)
+      named_saved_states = dict(
+          zip(state_saver.saved_state.keys(), saved_states))
+
+      for i in range(8):
+        self.assertAllEqual(
+            last_states[i],
+            named_saved_states[flat_state_names[i]])
 
   def _testProjNoSharding(self, use_gpu):
     num_units = 3
@@ -332,7 +404,7 @@
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(None, input_size))]
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer)
       outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
       self.assertEqual(len(outputs), len(inputs))
@@ -353,21 +425,21 @@
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(None, input_size))]
       cell_notuple = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer)
       cell_tuple = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer, state_is_tuple=True)
       outputs_notuple, state_notuple = tf.nn.rnn(
           cell_notuple, inputs, dtype=tf.float32,
           sequence_length=sequence_length)
       tf.get_variable_scope().reuse_variables()
-      outputs_tuple, state_is_tuple = tf.nn.rnn(
+      outputs_tuple, state_tuple = tf.nn.rnn(
           cell_tuple, inputs, dtype=tf.float32,
           sequence_length=sequence_length)
       self.assertEqual(len(outputs_notuple), len(inputs))
       self.assertEqual(len(outputs_tuple), len(inputs))
-      self.assertTrue(isinstance(state_is_tuple, tuple))
+      self.assertTrue(isinstance(state_tuple, tuple))
       self.assertTrue(isinstance(state_notuple, tf.Tensor))
 
       tf.initialize_all_variables().run()
@@ -380,9 +452,9 @@
       (state_notuple_v,) = sess.run(
           (state_notuple,), feed_dict={inputs[0]: input_value})
-      state_is_tuple_v = sess.run(
-          state_is_tuple, feed_dict={inputs[0]: input_value})
-      self.assertAllEqual(state_notuple_v, np.hstack(state_is_tuple_v))
+      state_tuple_v = sess.run(
+          state_tuple, feed_dict={inputs[0]: input_value})
+      self.assertAllEqual(state_notuple_v, np.hstack(state_tuple_v))
 
   def _testProjSharding(self, use_gpu):
     num_units = 3
@@ -400,7 +472,6 @@
 
       cell = tf.nn.rnn_cell.LSTMCell(
           num_units,
-          input_size=input_size,
           use_peepholes=True,
           num_proj=num_proj,
           num_unit_shards=num_unit_shards,
@@ -430,7 +501,6 @@
 
       cell = tf.nn.rnn_cell.LSTMCell(
           num_units,
-          input_size=input_size,
          use_peepholes=True,
           num_proj=num_proj,
           num_unit_shards=num_unit_shards,
@@ -455,7 +525,6 @@
 
       cell = tf.nn.rnn_cell.LSTMCell(
           num_units,
-          input_size=input_size,
           use_peepholes=True,
           num_proj=num_proj,
           num_unit_shards=num_unit_shards,
@@ -487,7 +556,7 @@
       initializer = tf.constant_initializer(0.001)
 
       cell_noshard = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size,
+          num_units,
           num_proj=num_proj,
           use_peepholes=True,
           initializer=initializer,
@@ -495,7 +564,7 @@
           num_proj_shards=num_proj_shards)
 
       cell_shard = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           initializer=initializer, num_proj=num_proj)
 
       with tf.variable_scope("noshard_scope"):
@@ -541,7 +610,6 @@
 
       cell = tf.nn.rnn_cell.LSTMCell(
           num_units,
-          input_size=input_size,
           use_peepholes=True,
           num_proj=num_proj,
           num_unit_shards=num_unit_shards,
@@ -577,10 +645,10 @@
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(None, input_size))]
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer)
       cell_d = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer_d)
 
       with tf.variable_scope("share_scope"):
@@ -616,7 +684,7 @@
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(None, input_size))]
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer)
 
       with tf.name_scope("scope0"):
@@ -649,7 +717,7 @@
           tf.placeholder(tf.float32, shape=(None, input_size))]
       inputs_c = tf.pack(inputs)
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer, state_is_tuple=True)
       outputs_static, state_static = tf.nn.rnn(
           cell, inputs, dtype=tf.float32,
@@ -675,6 +743,61 @@
       self.assertAllEqual(
           np.hstack(state_static_v), np.hstack(state_dynamic_v))
 
+  def testDynamicRNNWithNestedTupleStates(self):
+    num_units = 3
+    input_size = 5
+    batch_size = 2
+    num_proj = 4
+    max_length = 8
+    sequence_length = [4, 6]
+    with self.test_session(graph=tf.Graph()) as sess:
+      initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
+      inputs = max_length * [
+          tf.placeholder(tf.float32, shape=(None, input_size))]
+      inputs_c = tf.pack(inputs)
+      def _cell(i):
+        return tf.nn.rnn_cell.LSTMCell(
+            num_units + i, use_peepholes=True,
+            num_proj=num_proj + i, initializer=initializer, state_is_tuple=True)
+
+      # This creates a state tuple which has 4 sub-tuples of length 2 each.
+      cell = tf.nn.rnn_cell.MultiRNNCell(
+          [_cell(i) for i in range(4)], state_is_tuple=True)
+
+      self.assertEqual(len(cell.state_size), 4)
+      for i in range(4):
+        self.assertEqual(len(cell.state_size[i]), 2)
+
+      test_zero = cell.zero_state(1, tf.float32)
+      self.assertEqual(len(test_zero), 4)
+      for i in range(4):
+        self.assertEqual(test_zero[i][0].get_shape()[1], cell.state_size[i][0])
+        self.assertEqual(test_zero[i][1].get_shape()[1], cell.state_size[i][1])
+
+      outputs_static, state_static = tf.nn.rnn(
+          cell, inputs, dtype=tf.float32,
+          sequence_length=sequence_length)
+      tf.get_variable_scope().reuse_variables()
+      outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn(
+          cell, inputs_c, dtype=tf.float32, time_major=True,
+          sequence_length=sequence_length)
+
+      tf.initialize_all_variables().run()
+
+      input_value = np.random.randn(batch_size, input_size)
+      outputs_static_v = sess.run(
+          outputs_static, feed_dict={inputs[0]: input_value})
+      outputs_dynamic_v = sess.run(
+          outputs_dynamic, feed_dict={inputs[0]: input_value})
+      self.assertAllEqual(outputs_static_v, outputs_dynamic_v)
+
+      state_static_v = sess.run(
+          _unpacked_state(state_static), feed_dict={inputs[0]: input_value})
+      state_dynamic_v = sess.run(
+          _unpacked_state(state_dynamic), feed_dict={inputs[0]: input_value})
+      self.assertAllEqual(
+          np.hstack(state_static_v), np.hstack(state_dynamic_v))
+
   def _testDynamicEquivalentToStaticRNN(self, use_gpu, use_sequence_length):
     time_steps = 8
     num_units = 3
@@ -697,7 +820,7 @@
       initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
 
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
          initializer=initializer, num_proj=num_proj)
 
       with tf.variable_scope("dynamic_scope"):
@@ -752,7 +875,7 @@
       initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
 
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           initializer=initializer, num_proj=num_proj)
 
       with tf.variable_scope("dynamic_scope"):
@@ -1010,8 +1133,7 @@ def _static_vs_dynamic_rnn_benchmark_static(inputs_list_t, sequence_length):
   (_, input_size) = inputs_list_t[0].get_shape().as_list()
   initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
   cell = tf.nn.rnn_cell.LSTMCell(
-      num_units=input_size, input_size=input_size, use_peepholes=True,
-      initializer=initializer)
+      num_units=input_size, use_peepholes=True, initializer=initializer)
   outputs, final_state = tf.nn.rnn(
       cell, inputs_list_t, sequence_length=sequence_length, dtype=tf.float32)
 
@@ -1025,8 +1147,7 @@ def _static_vs_dynamic_rnn_benchmark_dynamic(inputs_t, sequence_length):
   (unused_0, unused_1, input_size) = inputs_t.get_shape().as_list()
   initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
   cell = tf.nn.rnn_cell.LSTMCell(
-      num_units=input_size, input_size=input_size, use_peepholes=True,
-      initializer=initializer)
+      num_units=input_size, use_peepholes=True, initializer=initializer)
   outputs, final_state = tf.nn.dynamic_rnn(
       cell, inputs_t, sequence_length=sequence_length, dtype=tf.float32)
 
@@ -1129,8 +1250,7 @@ def _half_seq_len_vs_unroll_half_rnn_benchmark(inputs_list_t, sequence_length):
   (_, input_size) = inputs_list_t[0].get_shape().as_list()
   initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
   cell = tf.nn.rnn_cell.LSTMCell(
-      num_units=input_size, input_size=input_size, use_peepholes=True,
-      initializer=initializer)
+      num_units=input_size, use_peepholes=True, initializer=initializer)
   outputs, final_state = tf.nn.rnn(
       cell, inputs_list_t, sequence_length=sequence_length, dtype=tf.float32)
 
@@ -1183,7 +1303,7 @@ def _concat_state_vs_tuple_state_rnn_benchmark(
   (_, input_size) = inputs_list_t[0].get_shape().as_list()
   initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
   cell = tf.nn.rnn_cell.LSTMCell(
-      num_units=input_size, input_size=input_size, use_peepholes=True,
+      num_units=input_size, use_peepholes=True,
       initializer=initializer, state_is_tuple=state_is_tuple)
   outputs, final_state = tf.nn.rnn(
       cell, inputs_list_t, sequence_length=sequence_length, dtype=tf.float32)
@@ -1239,8 +1359,7 @@ def _dynamic_rnn_swap_memory_benchmark(inputs_t, sequence_length,
   (unused_0, unused_1, input_size) = inputs_t.get_shape().as_list()
   initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
   cell = tf.nn.rnn_cell.LSTMCell(
-      num_units=input_size, input_size=input_size, use_peepholes=True,
-      initializer=initializer)
+      num_units=input_size, use_peepholes=True, initializer=initializer)
   outputs, final_state = tf.nn.dynamic_rnn(
       cell, inputs_t, sequence_length=sequence_length,
       swap_memory=swap_memory, dtype=tf.float32)
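
The rnn.py changes below generalize `state_saving_rnn` from flat state tuples
to arbitrarily nested ones: state names are flattened with `_unpacked_state`,
matched one-to-one against the flattened `cell.state_size`, and the fetched
initial states are re-nested with `_packed_state`. A rough sketch of the
calling convention the new test above exercises (two layers here for brevity;
`cell`, `inputs`, and `state_saver` are assumed to exist, and any object with
`state(name)` and `save_state(name, state)` methods works as the saver):

    # One (name_c, name_m) pair per LSTM layer, nested like cell.state_size.
    state_names = (("c0", "m0"), ("c1", "m1"))
    outputs, state = tf.nn.state_saving_rnn(
        cell, inputs, state_saver=state_saver, state_name=state_names)
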
%d" + % (len(state_name_flat), len(state_size_flat))) + + initial_state = _packed_state( + structure=state_name, + state=[state_saver.state(n) for n in state_name_flat]) else: initial_state = state_saver.state(state_name) @@ -198,8 +211,10 @@ def state_saving_rnn(cell, inputs, state_saver, state_name, sequence_length=sequence_length, scope=scope) if state_is_tuple: + state_flat = _unpacked_state(state) save_state = [ - state_saver.save_state(n, s) for (n, s) in zip(state_name, state)] + state_saver.save_state(n, s) + for (n, s) in zip(state_name_flat, state_flat)] else: save_state = [state_saver.save_state(state_name, state)] @@ -262,9 +277,10 @@ def _rnn_step( that returned by `state_size`. """ - state_is_tuple = isinstance(state, (list, tuple)) + state_is_tuple = _is_sequence(state) + orig_state = state # Convert state to a list for ease of use - state = list(state) if state_is_tuple else [state] + state = list(_unpacked_state(state)) if state_is_tuple else [state] state_shape = [s.get_shape() for s in state] def _copy_some_through(new_output, new_state): @@ -279,7 +295,8 @@ def _rnn_step( def _maybe_copy_some_through(): """Run RNN step. Pass through either no or some past state.""" new_output, new_state = call_cell() - new_state = list(new_state) if state_is_tuple else [new_state] + new_state = ( + list(_unpacked_state(new_state)) if state_is_tuple else [new_state]) if len(state) != len(new_state): raise ValueError( @@ -300,7 +317,8 @@ def _rnn_step( # steps. This is faster when max_seq_len is equal to the number of unrolls # (which is typical for dynamic_rnn). new_output, new_state = call_cell() - new_state = list(new_state) if state_is_tuple else [new_state] + new_state = ( + list(_unpacked_state(new_state)) if state_is_tuple else [new_state]) if len(state) != len(new_state): raise ValueError( @@ -325,7 +343,9 @@ def _rnn_step( final_state_i.set_shape(state_shape_i) if state_is_tuple: - return (final_output, tuple(final_state)) + return ( + final_output, + _packed_state(structure=orig_state, state=final_state)) else: return (final_output, final_state[0]) @@ -613,9 +633,9 @@ def _dynamic_rnn_loop( time = array_ops.constant(0, dtype=dtypes.int32, name="time") state_size = cell.state_size - state_is_tuple = isinstance(state_size, (list, tuple)) + state_is_tuple = _is_sequence(state_size) - state = tuple(state) if state_is_tuple else (state,) + state = _unpacked_state(state) if state_is_tuple else (state,) with ops.op_scope([], "dynamic_rnn") as scope: base_name = scope @@ -646,8 +666,9 @@ def _dynamic_rnn_loop( # Restore some shape information input_t.set_shape([const_batch_size, const_depth]) - # Unpack state if not using state tuples - state = tuple(state) if state_is_tuple else state[0] + # Pack state back up for use by cell + state = (_packed_state(structure=state_size, state=state) + if state_is_tuple else state[0]) call_cell = lambda: cell(input_t, state) @@ -665,7 +686,7 @@ def _dynamic_rnn_loop( (output, new_state) = call_cell() # Pack state if using state tuples - new_state = tuple(new_state) if state_is_tuple else (new_state,) + new_state = _unpacked_state(new_state) if state_is_tuple else (new_state,) output_ta_t = output_ta_t.write(time, output) @@ -686,6 +707,7 @@ def _dynamic_rnn_loop( const_time_steps, const_batch_size, cell.output_size]) # Unpack final state if not using state tuples. 
-  final_state = tuple(final_state) if state_is_tuple else final_state[0]
+  final_state = (
+      _unpacked_state(final_state) if state_is_tuple else final_state[0])
 
   return (final_outputs, final_state)
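
rnn_cell.py, below, is where the pack/unpack helpers actually live. Their
contract, restated from PackStateTest above: `_unpacked_state` flattens a
nested structure depth-first, and `_packed_state` maps a flat sequence back
onto a structure of the same shape.

    structure = ((3, 4), 5, (6, 7, (9, 10), 8))
    flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
    _unpacked_state(structure)      # (3, 4, 5, 6, 7, 9, 10, 8)
    _packed_state(structure, flat)  # (("a", "b"), "c", ("d", "e", ("f", "g"), "h"))

This round trip is what lets `_dynamic_rnn_loop` above carry the state through
the loop as a flat tuple of tensors, repacking it only around the cell call.
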
diff --git a/tensorflow/python/ops/rnn_cell.py b/tensorflow/python/ops/rnn_cell.py
index bfd0758883b..69ff7775d52 100644
--- a/tensorflow/python/ops/rnn_cell.py
+++ b/tensorflow/python/ops/rnn_cell.py
@@ -18,11 +18,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import math
 
-# pylint: disable=redefined-builtin,unused-import
-from six.moves import xrange
-# pylint: enable=redefined-builtin,unused-import
+import six
 
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -39,6 +38,88 @@ from tensorflow.python.ops.math_ops import tanh
 from tensorflow.python.platform import tf_logging as logging
 
 
+def _is_sequence(seq):
+  return (isinstance(seq, collections.Sequence)
+          and not isinstance(seq, six.string_types))
+
+
+def _packed_state_with_indices(structure, flat, index):
+  """Helper function for _packed_state.
+
+  Args:
+    structure: Substructure (tuple of elements and/or tuples) to mimic
+    flat: Flattened values to output substructure for.
+    index: Index at which to start reading from flat.
+
+  Returns:
+    The tuple (new_index, child), where:
+      * new_index - the updated index into `flat` having processed `structure`.
+      * packed - the subset of `flat` corresponding to `structure`,
+        having started at `index`, and packed into the same nested format.
+
+  Raises:
+    ValueError: if `structure` contains more elements than `flat`
+      (assuming indexing starts from `index`).
+  """
+  packed = []
+  for s in structure:
+    if _is_sequence(s):
+      new_index, child = _packed_state_with_indices(s, flat, index)
+      packed.append(type(s)(child))
+      index = new_index
+    else:
+      packed.append(flat[index])
+      index += 1
+  return (index, packed)
+
+
+def _yield_unpacked_state(state):
+  for s in state:
+    if _is_sequence(s):
+      for si in _yield_unpacked_state(s):
+        yield si
+    else:
+      yield s
+
+
+def _unpacked_state(state):
+  if not _is_sequence(state):
+    raise TypeError("state must be a sequence")
+  return type(state)(_yield_unpacked_state(state))
+
+
+def _packed_state(structure, state):
+  """Returns the flat state packed into a recursive tuple like structure.
+
+  Args:
+    structure: tuple or list constructed of scalars and/or other tuples/lists.
+    state: flattened state.
+
+  Returns:
+    packed: `state` converted to have the same recursive structure as
+      `structure`.
+
+  Raises:
+    TypeError: If structure or state is not a tuple or list.
+    ValueError: If state and structure have different element counts.
+  """
+  if not _is_sequence(structure):
+    raise TypeError("structure must be a sequence")
+  if not _is_sequence(state):
+    raise TypeError("state must be a sequence")
+
+  flat_structure = _unpacked_state(structure)
+  if len(flat_structure) != len(state):
+    raise ValueError(
+        "Internal error: Could not pack state. Structure had %d elements, but "
+        "state had %d elements. Structure: %s, state: %s."
+        % (len(flat_structure), len(state), structure, state))
+
+  (_, packed) = _packed_state_with_indices(structure, state, 0)
+  return type(structure)(packed)
+
+
 class RNNCell(object):
   """Abstract object representing an RNN cell.
 
@@ -98,17 +179,19 @@ class RNNCell(object):
       If `state_size` is an int, then the return value is a `2-D` tensor of
      shape `[batch_size x state_size]` filled with zeros.
 
-      If `state_size` is a list or tuple of ints, then the return value is
-      a tuple of `2-D` tensors with shape
-      `[batch_size x s] for s in state_size`.
+      If `state_size` is a nested list or tuple, then the return value is
+      a nested list or tuple (of the same structure) of `2-D` tensors with
+      the shapes `[batch_size x s]` for each s in `state_size`.
     """
     state_size = self.state_size
-    if isinstance(state_size, (list, tuple)):
-      zeros = tuple(
+    if _is_sequence(state_size):
+      state_size_flat = _unpacked_state(state_size)
+      zeros_flat = [
           array_ops.zeros(array_ops.pack([batch_size, s]), dtype=dtype)
-          for s in state_size)
-      for s, z in zip(state_size, zeros):
+          for s in state_size_flat]
+      for s, z in zip(state_size_flat, zeros_flat):
         z.set_shape([None, s])
+      zeros = _packed_state(structure=state_size, state=zeros_flat)
     else:
       zeros = array_ops.zeros(
           array_ops.pack([batch_size, state_size]), dtype=dtype)
@@ -675,7 +758,7 @@ class MultiRNNCell(RNNCell):
     self._cells = cells
     self._state_is_tuple = state_is_tuple
     if not state_is_tuple:
-      if any(isinstance(c.state_size, (list, tuple)) for c in self._cells):
+      if any(_is_sequence(c.state_size) for c in self._cells):
        raise ValueError("Some cells return tuples of states, but the flag "
                          "state_is_tuple is not set. State sizes are: %s"
                          % str([c.state_size for c in self._cells]))
@@ -700,7 +783,7 @@ class MultiRNNCell(RNNCell):
       for i, cell in enumerate(self._cells):
         with vs.variable_scope("Cell%d" % i):
           if self._state_is_tuple:
-            if not isinstance(state, (list, tuple)):
+            if not _is_sequence(state):
               raise ValueError(
                   "Expected state to be a tuple of length %d, but received: %s"
                  % (len(self.state_size), state))
@@ -778,9 +861,9 @@ def _linear(args, output_size, bias, bias_start=0.0, scope=None):
   Raises:
     ValueError: if some of the arguments has unspecified or wrong shape.
   """
-  if args is None or (isinstance(args, (list, tuple)) and not args):
+  if args is None or (_is_sequence(args) and not args):
     raise ValueError("`args` must be specified")
-  if not isinstance(args, (list, tuple)):
+  if not _is_sequence(args):
    args = [args]
 
   # Calculate the total size of arguments on dimension 1.
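
One subtlety worth noting about `_is_sequence` as defined above: it
deliberately excludes strings, which are `collections.Sequence` instances in
Python but must never be mistaken for a state tuple. A quick sketch of the
expected behavior:

    _is_sequence((3, 4))   # True
    _is_sequence([3, 4])   # True
    _is_sequence("abc")    # False: strings are Sequences, but not state
    _is_sequence(3)        # False

Dicts are also excluded (a dict is not a `collections.Sequence`), which is
consistent with `TestStateSaver.state` checking for a dict-typed `state_size`
explicitly rather than going through these helpers.
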