From ae5c66e3c2ce9bcf0f20588f2c55c7d12b8b76ed Mon Sep 17 00:00:00 2001
From: Eugene Brevdo <ebrevdo@gmail.com>
Date: Tue, 24 May 2016 14:06:37 -0800
Subject: [PATCH] Add support for arbitrarily nested tuples for RNN state.

Also fixed a bug in the RNN unit tests.
Change: 123150781
---
 .../python/kernel_tests/rnn_cell_test.py      |   7 +-
 tensorflow/python/kernel_tests/rnn_test.py    | 219 ++++++++++++++----
 tensorflow/python/ops/rnn.py                  |  62 +++--
 tensorflow/python/ops/rnn_cell.py             | 111 +++++++--
 4 files changed, 311 insertions(+), 88 deletions(-)

diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py
index e3756e03d25..10a1a2e2a39 100644
--- a/tensorflow/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/python/kernel_tests/rnn_cell_test.py
@@ -76,7 +76,7 @@ class RNNCellTest(tf.test.TestCase):
       with tf.variable_scope("other", initializer=tf.constant_initializer(0.5)):
         x = tf.zeros([1, 3])  # Test GRUCell with input_size != num_units.
         m = tf.zeros([1, 2])
-        g, _ = tf.nn.rnn_cell.GRUCell(2, input_size=3)(x, m)
+        g, _ = tf.nn.rnn_cell.GRUCell(2)(x, m)
         sess.run([tf.initialize_all_variables()])
         res = sess.run([g], {x.name: np.array([[1., 1., 1.]]),
                              m.name: np.array([[0.1, 0.1]])})
@@ -104,7 +104,7 @@ class RNNCellTest(tf.test.TestCase):
       with tf.variable_scope("other", initializer=tf.constant_initializer(0.5)):
         x = tf.zeros([1, 3])  # Test BasicLSTMCell with input_size != num_units.
         m = tf.zeros([1, 4])
-        g, out_m = tf.nn.rnn_cell.BasicLSTMCell(2, input_size=3)(x, m)
+        g, out_m = tf.nn.rnn_cell.BasicLSTMCell(2)(x, m)
         sess.run([tf.initialize_all_variables()])
         res = sess.run([g, out_m], {x.name: np.array([[1., 1., 1.]]),
                                     m.name: 0.1 * np.ones([1, 4])})
@@ -147,8 +147,7 @@ class RNNCellTest(tf.test.TestCase):
         x = tf.zeros([batch_size, input_size])
         m = tf.zeros([batch_size, state_size])
         output, state = tf.nn.rnn_cell.LSTMCell(
-            num_units=num_units, input_size=input_size,
-            num_proj=num_proj, forget_bias=1.0)(x, m)
+            num_units=num_units, num_proj=num_proj, forget_bias=1.0)(x, m)
         sess.run([tf.initialize_all_variables()])
         res = sess.run([output, state],
                        {x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]),
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index 646c981791d..469635ae4f8 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -26,9 +26,15 @@ import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
 
+from tensorflow.python.ops import rnn_cell
 
-def _flatten(list_of_lists):
-  return [x for y in list_of_lists for x in y]
+# pylint: disable=protected-access
+_is_sequence = rnn_cell._is_sequence
+_unpacked_state = rnn_cell._unpacked_state
+_packed_state = rnn_cell._packed_state
+# pylint: enable=protected-access
+
+_flatten = _unpacked_state
 
 
 class Plus1RNNCell(tf.nn.rnn_cell.RNNCell):
@@ -48,24 +54,32 @@ class Plus1RNNCell(tf.nn.rnn_cell.RNNCell):
 
 class TestStateSaver(object):
 
-  def __init__(self, batch_size, state_size, state_is_tuple=False):
+  def __init__(self, batch_size, state_size):
     self._batch_size = batch_size
     self._state_size = state_size
-    self._state_is_tuple = state_is_tuple
     self.saved_state = {}
 
-  def state(self, _):
-    if self._state_is_tuple:
-      return tuple(
-          tf.zeros(tf.pack([self._batch_size, s])) for s in self._state_size)
+  def state(self, name):
+    if isinstance(self._state_size, dict):
+      return tf.zeros([self._batch_size, self._state_size[name]])
     else:
-      return tf.zeros(tf.pack([self._batch_size, self._state_size]))
+      return tf.zeros([self._batch_size, self._state_size])
 
   def save_state(self, name, state):
     self.saved_state[name] = state
     return tf.identity(state)
 
 
+class PackStateTest(tf.test.TestCase):
+
+  def testPackUnpackState(self):
+    structure = ((3, 4), 5, (6, 7, (9, 10), 8))
+    flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
+    self.assertEqual(_unpacked_state(structure), (3, 4, 5, 6, 7, 9, 10, 8))
+    self.assertEqual(_packed_state(structure, flat),
+                     (("a", "b"), "c", ("d", "e", ("f", "g"), "h")))
+
+
 class RNNTest(tf.test.TestCase):
 
   def setUp(self):
@@ -197,7 +211,7 @@ class GRUTest(tf.test.TestCase):
       concat_inputs = tf.placeholder(
           tf.float32, shape=(time_steps, batch_size, input_size))
 
-      cell = tf.nn.rnn_cell.GRUCell(num_units=num_units, input_size=input_size)
+      cell = tf.nn.rnn_cell.GRUCell(num_units=num_units)
 
       with tf.variable_scope("dynamic_scope"):
         outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn(
@@ -229,8 +243,7 @@ class LSTMTest(tf.test.TestCase):
     max_length = 8
     with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
       initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
-      cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, initializer=initializer)
+      cell = tf.nn.rnn_cell.LSTMCell(num_units, initializer=initializer)
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(batch_size, input_size))]
       outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
@@ -250,8 +263,7 @@ class LSTMTest(tf.test.TestCase):
     with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
       initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
-          cell_clip=0.0, initializer=initializer)
+          num_units, use_peepholes=True, cell_clip=0.0, initializer=initializer)
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(batch_size, input_size))]
       outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
@@ -276,7 +288,7 @@ class LSTMTest(tf.test.TestCase):
       initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
       state_saver = TestStateSaver(batch_size, 2 * num_units)
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=False, initializer=initializer)
+          num_units, use_peepholes=False, initializer=initializer)
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(batch_size, input_size))]
       with tf.variable_scope("share_scope"):
@@ -293,16 +305,16 @@ class LSTMTest(tf.test.TestCase):
           feed_dict={inputs[0]: input_value})
       self.assertAllEqual(last_state_value, saved_state_value)
 
-  def _testNoProjNoShardingTupleStateSaver(self, use_gpu):
+  def testNoProjNoShardingTupleStateSaver(self):
     num_units = 3
     input_size = 5
     batch_size = 2
     max_length = 8
-    with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
+    with self.test_session(graph=tf.Graph()) as sess:
       initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
-      state_saver = TestStateSaver(batch_size, (num_units, num_units))
+      state_saver = TestStateSaver(batch_size, num_units)
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=False, initializer=initializer,
+          num_units, use_peepholes=False, initializer=initializer,
           state_is_tuple=True)
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(batch_size, input_size))]
@@ -316,10 +328,70 @@ class LSTMTest(tf.test.TestCase):
       tf.initialize_all_variables().run()
       input_value = np.random.randn(batch_size, input_size)
       last_and_saved_states = sess.run(
-          state + state_saver.saved_state.values(),
+          state + (state_saver.saved_state["c"], state_saver.saved_state["m"]),
           feed_dict={inputs[0]: input_value})
       self.assertEqual(4, len(last_and_saved_states))
-      self.assertEqual(last_and_saved_states[:2], last_and_saved_states[2:])
+      self.assertAllEqual(last_and_saved_states[:2], last_and_saved_states[2:])
+
+  def testNoProjNoShardingNestedTupleStateSaver(self):
+    num_units = 3
+    input_size = 5
+    batch_size = 2
+    max_length = 8
+    with self.test_session(graph=tf.Graph()) as sess:
+      initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
+      state_saver = TestStateSaver(batch_size, {"c0": num_units,
+                                                "m0": num_units,
+                                                "c1": num_units + 1,
+                                                "m1": num_units + 1,
+                                                "c2": num_units + 2,
+                                                "m2": num_units + 2,
+                                                "c3": num_units + 3,
+                                                "m3": num_units + 3})
+      def _cell(i):
+        return tf.nn.rnn_cell.LSTMCell(
+            num_units + i, use_peepholes=False, initializer=initializer,
+            state_is_tuple=True)
+
+      # This creates a state tuple which has 4 sub-tuples of length 2 each.
+      cell = tf.nn.rnn_cell.MultiRNNCell(
+          [_cell(i) for i in range(4)], state_is_tuple=True)
+
+      self.assertEqual(len(cell.state_size), 4)
+      for i in range(4):
+        self.assertEqual(len(cell.state_size[i]), 2)
+
+      inputs = max_length * [
+          tf.placeholder(tf.float32, shape=(batch_size, input_size))]
+
+      state_names = (("c0", "m0"), ("c1", "m1"),
+                     ("c2", "m2"), ("c3", "m3"))
+      with tf.variable_scope("share_scope"):
+        outputs, state = tf.nn.state_saving_rnn(
+            cell, inputs, state_saver=state_saver, state_name=state_names)
+      self.assertEqual(len(outputs), len(inputs))
+
+      # Final output comes from _cell(3) which has state size num_units + 3
+      for out in outputs:
+        self.assertEqual(out.get_shape().as_list(), [batch_size, num_units + 3])
+
+      tf.initialize_all_variables().run()
+      input_value = np.random.randn(batch_size, input_size)
+      last_states = sess.run(
+          list(_unpacked_state(state)), feed_dict={inputs[0]: input_value})
+      saved_states = sess.run(
+          list(state_saver.saved_state.values()),
+          feed_dict={inputs[0]: input_value})
+      self.assertEqual(8, len(last_states))
+      self.assertEqual(8, len(saved_states))
+      flat_state_names = _unpacked_state(state_names)
+      named_saved_states = dict(
+          zip(state_saver.saved_state.keys(), saved_states))
+
+      for i in range(8):
+        self.assertAllEqual(
+            last_states[i],
+            named_saved_states[flat_state_names[i]])
 
   def _testProjNoSharding(self, use_gpu):
     num_units = 3
@@ -332,7 +404,7 @@ class LSTMTest(tf.test.TestCase):
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(None, input_size))]
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer)
       outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
       self.assertEqual(len(outputs), len(inputs))
@@ -353,21 +425,21 @@ class LSTMTest(tf.test.TestCase):
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(None, input_size))]
       cell_notuple = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer)
       cell_tuple = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer, state_is_tuple=True)
       outputs_notuple, state_notuple = tf.nn.rnn(
           cell_notuple, inputs, dtype=tf.float32,
           sequence_length=sequence_length)
       tf.get_variable_scope().reuse_variables()
-      outputs_tuple, state_is_tuple = tf.nn.rnn(
+      outputs_tuple, state_tuple = tf.nn.rnn(
           cell_tuple, inputs, dtype=tf.float32,
           sequence_length=sequence_length)
       self.assertEqual(len(outputs_notuple), len(inputs))
       self.assertEqual(len(outputs_tuple), len(inputs))
-      self.assertTrue(isinstance(state_is_tuple, tuple))
+      self.assertTrue(isinstance(state_tuple, tuple))
       self.assertTrue(isinstance(state_notuple, tf.Tensor))
 
       tf.initialize_all_variables().run()
@@ -380,9 +452,9 @@ class LSTMTest(tf.test.TestCase):
 
       (state_notuple_v,) = sess.run(
           (state_notuple,), feed_dict={inputs[0]: input_value})
-      state_is_tuple_v = sess.run(
-          state_is_tuple, feed_dict={inputs[0]: input_value})
-      self.assertAllEqual(state_notuple_v, np.hstack(state_is_tuple_v))
+      state_tuple_v = sess.run(
+          state_tuple, feed_dict={inputs[0]: input_value})
+      self.assertAllEqual(state_notuple_v, np.hstack(state_tuple_v))
 
   def _testProjSharding(self, use_gpu):
     num_units = 3
@@ -400,7 +472,6 @@ class LSTMTest(tf.test.TestCase):
 
       cell = tf.nn.rnn_cell.LSTMCell(
           num_units,
-          input_size=input_size,
           use_peepholes=True,
           num_proj=num_proj,
           num_unit_shards=num_unit_shards,
@@ -430,7 +501,6 @@ class LSTMTest(tf.test.TestCase):
 
       cell = tf.nn.rnn_cell.LSTMCell(
           num_units,
-          input_size=input_size,
           use_peepholes=True,
           num_proj=num_proj,
           num_unit_shards=num_unit_shards,
@@ -455,7 +525,6 @@ class LSTMTest(tf.test.TestCase):
 
       cell = tf.nn.rnn_cell.LSTMCell(
           num_units,
-          input_size=input_size,
           use_peepholes=True,
           num_proj=num_proj,
           num_unit_shards=num_unit_shards,
@@ -487,7 +556,7 @@ class LSTMTest(tf.test.TestCase):
       initializer = tf.constant_initializer(0.001)
 
       cell_noshard = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size,
+          num_units,
           num_proj=num_proj,
           use_peepholes=True,
           initializer=initializer,
@@ -495,7 +564,7 @@ class LSTMTest(tf.test.TestCase):
           num_proj_shards=num_proj_shards)
 
       cell_shard = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           initializer=initializer, num_proj=num_proj)
 
       with tf.variable_scope("noshard_scope"):
@@ -541,7 +610,6 @@ class LSTMTest(tf.test.TestCase):
 
       cell = tf.nn.rnn_cell.LSTMCell(
           num_units,
-          input_size=input_size,
           use_peepholes=True,
           num_proj=num_proj,
           num_unit_shards=num_unit_shards,
@@ -577,10 +645,10 @@ class LSTMTest(tf.test.TestCase):
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(None, input_size))]
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer)
       cell_d = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer_d)
 
       with tf.variable_scope("share_scope"):
@@ -616,7 +684,7 @@ class LSTMTest(tf.test.TestCase):
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(None, input_size))]
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer)
 
       with tf.name_scope("scope0"):
@@ -649,7 +717,7 @@ class LSTMTest(tf.test.TestCase):
           tf.placeholder(tf.float32, shape=(None, input_size))]
       inputs_c = tf.pack(inputs)
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer, state_is_tuple=True)
       outputs_static, state_static = tf.nn.rnn(
           cell, inputs, dtype=tf.float32,
@@ -675,6 +743,61 @@ class LSTMTest(tf.test.TestCase):
       self.assertAllEqual(
           np.hstack(state_static_v), np.hstack(state_dynamic_v))
 
+  def testDynamicRNNWithNestedTupleStates(self):
+    num_units = 3
+    input_size = 5
+    batch_size = 2
+    num_proj = 4
+    max_length = 8
+    sequence_length = [4, 6]
+    with self.test_session(graph=tf.Graph()) as sess:
+      initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
+      inputs = max_length * [
+          tf.placeholder(tf.float32, shape=(None, input_size))]
+      inputs_c = tf.pack(inputs)
+      def _cell(i):
+        return tf.nn.rnn_cell.LSTMCell(
+            num_units + i, use_peepholes=True,
+            num_proj=num_proj + i, initializer=initializer, state_is_tuple=True)
+
+      # This creates a state tuple which has 4 sub-tuples of length 2 each.
+      cell = tf.nn.rnn_cell.MultiRNNCell(
+          [_cell(i) for i in range(4)], state_is_tuple=True)
+
+      self.assertEqual(len(cell.state_size), 4)
+      for i in range(4):
+        self.assertEqual(len(cell.state_size[i]), 2)
+
+      test_zero = cell.zero_state(1, tf.float32)
+      self.assertEqual(len(test_zero), 4)
+      for i in range(4):
+        self.assertEqual(test_zero[i][0].get_shape()[1], cell.state_size[i][0])
+        self.assertEqual(test_zero[i][1].get_shape()[1], cell.state_size[i][1])
+
+      outputs_static, state_static = tf.nn.rnn(
+          cell, inputs, dtype=tf.float32,
+          sequence_length=sequence_length)
+      tf.get_variable_scope().reuse_variables()
+      outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn(
+          cell, inputs_c, dtype=tf.float32, time_major=True,
+          sequence_length=sequence_length)
+
+      tf.initialize_all_variables().run()
+
+      input_value = np.random.randn(batch_size, input_size)
+      outputs_static_v = sess.run(
+          outputs_static, feed_dict={inputs[0]: input_value})
+      outputs_dynamic_v = sess.run(
+          outputs_dynamic, feed_dict={inputs[0]: input_value})
+      self.assertAllEqual(outputs_static_v, outputs_dynamic_v)
+
+      state_static_v = sess.run(
+          _unpacked_state(state_static), feed_dict={inputs[0]: input_value})
+      state_dynamic_v = sess.run(
+          _unpacked_state(state_dynamic), feed_dict={inputs[0]: input_value})
+      self.assertAllEqual(
+          np.hstack(state_static_v), np.hstack(state_dynamic_v))
+
   def _testDynamicEquivalentToStaticRNN(self, use_gpu, use_sequence_length):
     time_steps = 8
     num_units = 3
@@ -697,7 +820,7 @@ class LSTMTest(tf.test.TestCase):
       initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
 
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           initializer=initializer, num_proj=num_proj)
 
       with tf.variable_scope("dynamic_scope"):
@@ -752,7 +875,7 @@ class LSTMTest(tf.test.TestCase):
       initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
 
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           initializer=initializer, num_proj=num_proj)
 
       with tf.variable_scope("dynamic_scope"):
@@ -1010,8 +1133,7 @@ def _static_vs_dynamic_rnn_benchmark_static(inputs_list_t, sequence_length):
   (_, input_size) = inputs_list_t[0].get_shape().as_list()
   initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
   cell = tf.nn.rnn_cell.LSTMCell(
-      num_units=input_size, input_size=input_size, use_peepholes=True,
-      initializer=initializer)
+      num_units=input_size, use_peepholes=True, initializer=initializer)
   outputs, final_state = tf.nn.rnn(
       cell, inputs_list_t, sequence_length=sequence_length, dtype=tf.float32)
 
@@ -1025,8 +1147,7 @@ def _static_vs_dynamic_rnn_benchmark_dynamic(inputs_t, sequence_length):
   (unused_0, unused_1, input_size) = inputs_t.get_shape().as_list()
   initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
   cell = tf.nn.rnn_cell.LSTMCell(
-      num_units=input_size, input_size=input_size, use_peepholes=True,
-      initializer=initializer)
+      num_units=input_size, use_peepholes=True, initializer=initializer)
   outputs, final_state = tf.nn.dynamic_rnn(
       cell, inputs_t, sequence_length=sequence_length, dtype=tf.float32)
 
@@ -1129,8 +1250,7 @@ def _half_seq_len_vs_unroll_half_rnn_benchmark(inputs_list_t, sequence_length):
   (_, input_size) = inputs_list_t[0].get_shape().as_list()
   initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
   cell = tf.nn.rnn_cell.LSTMCell(
-      num_units=input_size, input_size=input_size, use_peepholes=True,
-      initializer=initializer)
+      num_units=input_size, use_peepholes=True, initializer=initializer)
   outputs, final_state = tf.nn.rnn(
       cell, inputs_list_t, sequence_length=sequence_length, dtype=tf.float32)
 
@@ -1183,7 +1303,7 @@ def _concat_state_vs_tuple_state_rnn_benchmark(
   (_, input_size) = inputs_list_t[0].get_shape().as_list()
   initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
   cell = tf.nn.rnn_cell.LSTMCell(
-      num_units=input_size, input_size=input_size, use_peepholes=True,
+      num_units=input_size, use_peepholes=True,
       initializer=initializer, state_is_tuple=state_is_tuple)
   outputs, final_state = tf.nn.rnn(
       cell, inputs_list_t, sequence_length=sequence_length, dtype=tf.float32)
@@ -1239,8 +1359,7 @@ def _dynamic_rnn_swap_memory_benchmark(inputs_t, sequence_length,
   (unused_0, unused_1, input_size) = inputs_t.get_shape().as_list()
   initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
   cell = tf.nn.rnn_cell.LSTMCell(
-      num_units=input_size, input_size=input_size, use_peepholes=True,
-      initializer=initializer)
+      num_units=input_size, use_peepholes=True, initializer=initializer)
   outputs, final_state = tf.nn.dynamic_rnn(
       cell, inputs_t, sequence_length=sequence_length,
       swap_memory=swap_memory, dtype=tf.float32)
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 259811c2a4f..6d9a0d4e3f2 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -32,6 +32,13 @@ from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops import variable_scope as vs
 
 
+# pylint: disable=protected-access
+_is_sequence = rnn_cell._is_sequence
+_unpacked_state = rnn_cell._unpacked_state
+_packed_state = rnn_cell._packed_state
+# pylint: enable=protected-access
+
+
 def rnn(cell, inputs, initial_state=None, dtype=None,
         sequence_length=None, scope=None):
   """Creates a recurrent neural network specified by RNNCell `cell`.
@@ -177,20 +184,26 @@ def state_saving_rnn(cell, inputs, state_saver, state_name,
      type of `state_name` does not match that of `cell.state_size`.
   """
   state_size = cell.state_size
-  state_is_tuple = isinstance(state_size, (list, tuple))
-  state_name_tuple = isinstance(state_name, (list, tuple))
+  state_is_tuple = _is_sequence(state_size)
+  state_name_tuple = _is_sequence(state_name)
 
   if state_is_tuple != state_name_tuple:
     raise ValueError(
-        "state_name should be a tuple iff cell.state_size is.  state_name: %s, "
-        "cell.state_size: %s" % (str(state_name), str(state_size)))
+        "state_name should be the same type as cell.state_size.  "
+        "state_name: %s, cell.state_size: %s"
+        % (str(state_name), str(state_size)))
 
   if state_is_tuple:
-    if len(state_name) != len(state_size):
-      raise ValueError("len(state_name) != len(state_size): %d vs. %d"
-                       % (len(state_name), len(state_size)))
+    state_name_flat = _unpacked_state(state_name)
+    state_size_flat = _unpacked_state(state_size)
 
-    initial_state = tuple(state_saver.state(n) for n in state_name)
+    if len(state_name_flat) != len(state_size_flat):
+      raise ValueError("#elems(state_name) != #elems(state_size): %d vs. %d"
+                       % (len(state_name_flat), len(state_size_flat)))
+
+    initial_state = _packed_state(
+        structure=state_name,
+        state=[state_saver.state(n) for n in state_name_flat])
   else:
     initial_state = state_saver.state(state_name)
 
@@ -198,8 +211,10 @@ def state_saving_rnn(cell, inputs, state_saver, state_name,
                          sequence_length=sequence_length, scope=scope)
 
   if state_is_tuple:
+    state_flat = _unpacked_state(state)
     save_state = [
-        state_saver.save_state(n, s) for (n, s) in zip(state_name, state)]
+        state_saver.save_state(n, s)
+        for (n, s) in zip(state_name_flat, state_flat)]
   else:
     save_state = [state_saver.save_state(state_name, state)]
 
@@ -262,9 +277,10 @@ def _rnn_step(
       that returned by `state_size`.
   """
 
-  state_is_tuple = isinstance(state, (list, tuple))
+  state_is_tuple = _is_sequence(state)
+  orig_state = state
   # Convert state to a list for ease of use
-  state = list(state) if state_is_tuple else [state]
+  state = list(_unpacked_state(state)) if state_is_tuple else [state]
   state_shape = [s.get_shape() for s in state]
 
   def _copy_some_through(new_output, new_state):
@@ -279,7 +295,8 @@ def _rnn_step(
   def _maybe_copy_some_through():
     """Run RNN step.  Pass through either no or some past state."""
     new_output, new_state = call_cell()
-    new_state = list(new_state) if state_is_tuple else [new_state]
+    new_state = (
+        list(_unpacked_state(new_state)) if state_is_tuple else [new_state])
 
     if len(state) != len(new_state):
       raise ValueError(
@@ -300,7 +317,8 @@ def _rnn_step(
     # steps.  This is faster when max_seq_len is equal to the number of unrolls
     # (which is typical for dynamic_rnn).
     new_output, new_state = call_cell()
-    new_state = list(new_state) if state_is_tuple else [new_state]
+    new_state = (
+        list(_unpacked_state(new_state)) if state_is_tuple else [new_state])
 
     if len(state) != len(new_state):
       raise ValueError(
@@ -325,7 +343,9 @@ def _rnn_step(
     final_state_i.set_shape(state_shape_i)
 
   if state_is_tuple:
-    return (final_output, tuple(final_state))
+    return (
+        final_output,
+        _packed_state(structure=orig_state, state=final_state))
   else:
     return (final_output, final_state[0])
 
@@ -613,9 +633,9 @@ def _dynamic_rnn_loop(
   time = array_ops.constant(0, dtype=dtypes.int32, name="time")
 
   state_size = cell.state_size
-  state_is_tuple = isinstance(state_size, (list, tuple))
+  state_is_tuple = _is_sequence(state_size)
 
-  state = tuple(state) if state_is_tuple else (state,)
+  state = _unpacked_state(state) if state_is_tuple else (state,)
 
   with ops.op_scope([], "dynamic_rnn") as scope:
     base_name = scope
@@ -646,8 +666,9 @@ def _dynamic_rnn_loop(
     # Restore some shape information
     input_t.set_shape([const_batch_size, const_depth])
 
-    # Unpack state if not using state tuples
-    state = tuple(state) if state_is_tuple else state[0]
+    # Pack state back up for use by cell
+    state = (_packed_state(structure=state_size, state=state)
+             if state_is_tuple else state[0])
 
     call_cell = lambda: cell(input_t, state)
 
@@ -665,7 +686,7 @@ def _dynamic_rnn_loop(
       (output, new_state) = call_cell()
 
     # Pack state if using state tuples
-    new_state = tuple(new_state) if state_is_tuple else (new_state,)
+    new_state = _unpacked_state(new_state) if state_is_tuple else (new_state,)
 
     output_ta_t = output_ta_t.write(time, output)
 
@@ -686,6 +707,7 @@ def _dynamic_rnn_loop(
       const_time_steps, const_batch_size, cell.output_size])
 
   # Unpack final state if not using state tuples.
-  final_state = tuple(final_state) if state_is_tuple else final_state[0]
+  final_state = (
+      _unpacked_state(final_state) if state_is_tuple else final_state[0])
 
   return (final_outputs, final_state)
diff --git a/tensorflow/python/ops/rnn_cell.py b/tensorflow/python/ops/rnn_cell.py
index bfd0758883b..69ff7775d52 100644
--- a/tensorflow/python/ops/rnn_cell.py
+++ b/tensorflow/python/ops/rnn_cell.py
@@ -18,11 +18,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import math
 
-# pylint: disable=redefined-builtin,unused-import
-from six.moves import xrange
-# pylint: enable=redefined-builtin,unused-import
+import six
 
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -39,6 +38,88 @@ from tensorflow.python.ops.math_ops import tanh
 from tensorflow.python.platform import tf_logging as logging
 
 
+def _is_sequence(seq):
+  return (isinstance(seq, collections.Sequence)
+          and not isinstance(seq, six.string_types))
+
+
+def _packed_state_with_indices(structure, flat, index):
+  """Helper function for _packed_state.
+
+  Args:
+    structure: Substructure (tuple of elements and/or tuples) to mimic
+    flat: Flattened values to output substructure for.
+    index: Index at which to start reading from flat.
+
+  Returns:
+    The tuple (new_index, child), where:
+      * new_index - the updated index into `flat` having processed `structure`.
+      * packed - the subset of `flat` corresponding to `structure`,
+                 having started at `index`, and packed into the same nested
+                 format.
+
+  Raises:
+    ValueError: if `structure` contains more elements than `flat`
+      (assuming indexing starts from `index`).
+  """
+  packed = []
+  for s in structure:
+    if _is_sequence(s):
+      new_index, child = _packed_state_with_indices(s, flat, index)
+      packed.append(type(s)(child))
+      index = new_index
+    else:
+      packed.append(flat[index])
+      index += 1
+  return (index, packed)
+
+
+def _yield_unpacked_state(state):
+  for s in state:
+    if _is_sequence(s):
+      for si in _yield_unpacked_state(s):
+        yield si
+    else:
+      yield s
+
+
+def _unpacked_state(state):
+  if not _is_sequence(state):
+    raise TypeError("state must be a sequence")
+  return type(state)(_yield_unpacked_state(state))
+
+
+def _packed_state(structure, state):
+  """Returns the flat state packed into a recursive tuple like structure.
+
+  Args:
+    structure: tuple or list constructed of scalars and/or other tuples/lists.
+    state: flattened state.
+
+  Returns:
+    packed: `state` converted to have the same recursive structure as
+      `structure`.
+
+  Raises:
+    TypeError: If structure or state is not a tuple or list.
+    ValueError: If state and structure have different element counts.
+  """
+  if not _is_sequence(structure):
+    raise TypeError("structure must be a sequence")
+  if not _is_sequence(state):
+    raise TypeError("state must be a sequence")
+
+  flat_structure = _unpacked_state(structure)
+  if len(flat_structure) != len(state):
+    raise ValueError(
+        "Internal error: Could not pack state.  Structure had %d elements, but "
+        "state had %d elements.  Structure: %s, state: %s."
+        % (len(flat_structure), len(state), structure, state))
+
+  (_, packed) = _packed_state_with_indices(structure, state, 0)
+  return type(structure)(packed)
+
+
 class RNNCell(object):
   """Abstract object representing an RNN cell.
 
@@ -98,17 +179,19 @@ class RNNCell(object):
       If `state_size` is an int, then the return value is a `2-D` tensor of
       shape `[batch_size x state_size]` filled with zeros.
 
-      If `state_size` is a list or tuple of ints, then the return value is
-      a tuple of `2-D` tensors with shape
-      `[batch_size x s] for s in state_size`.
+      If `state_size` is a nested list or tuple, then the return value is
+      a nested list or tuple (of the same structure) of `2-D` tensors with
+    the shapes `[batch_size x s]` for each s in `state_size`.
     """
     state_size = self.state_size
-    if isinstance(state_size, (list, tuple)):
-      zeros = tuple(
+    if _is_sequence(state_size):
+      state_size_flat = _unpacked_state(state_size)
+      zeros_flat = [
           array_ops.zeros(array_ops.pack([batch_size, s]), dtype=dtype)
-          for s in state_size)
-      for s, z in zip(state_size, zeros):
+          for s in state_size_flat]
+      for s, z in zip(state_size_flat, zeros_flat):
         z.set_shape([None, s])
+      zeros = _packed_state(structure=state_size, state=zeros_flat)
     else:
       zeros = array_ops.zeros(
           array_ops.pack([batch_size, state_size]), dtype=dtype)
@@ -675,7 +758,7 @@ class MultiRNNCell(RNNCell):
     self._cells = cells
     self._state_is_tuple = state_is_tuple
     if not state_is_tuple:
-      if any(isinstance(c.state_size, (list, tuple)) for c in self._cells):
+      if any(_is_sequence(c.state_size) for c in self._cells):
         raise ValueError("Some cells return tuples of states, but the flag "
                          "state_is_tuple is not set.  State sizes are: %s"
                          % str([c.state_size for c in self._cells]))
@@ -700,7 +783,7 @@ class MultiRNNCell(RNNCell):
       for i, cell in enumerate(self._cells):
         with vs.variable_scope("Cell%d" % i):
           if self._state_is_tuple:
-            if not isinstance(state, (list, tuple)):
+            if not _is_sequence(state):
               raise ValueError(
                   "Expected state to be a tuple of length %d, but received: %s"
                   % (len(self.state_size), state))
@@ -778,9 +861,9 @@ def _linear(args, output_size, bias, bias_start=0.0, scope=None):
   Raises:
     ValueError: if some of the arguments has unspecified or wrong shape.
   """
-  if args is None or (isinstance(args, (list, tuple)) and not args):
+  if args is None or (_is_sequence(args) and not args):
     raise ValueError("`args` must be specified")
-  if not isinstance(args, (list, tuple)):
+  if not _is_sequence(args):
     args = [args]
 
   # Calculate the total size of arguments on dimension 1.