** BREAKING CHANGE **

Core RNNCell implementations now use state_is_tuple=True by default

This is part of the deprecation process for non-tuple LSTM and MultiRNNCell states.
Change: 130059769
Eugene Brevdo 2016-08-11 18:06:42 -08:00 committed by TensorFlower Gardener
parent fc218ef4d5
commit bc15695781
5 changed files with 417 additions and 245 deletions
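For downstream code, the practical effect of the new default is sketched below. This is a hedged illustration against the contrib/0.10-era API exercised by the tests in this commit, not part of the commit itself: callers who relied on states being concatenated along the column axis must now pass state_is_tuple=False explicitly, or switch to feeding and reading per-layer state tuples.

import tensorflow as tf

x = tf.zeros([1, 2])

with tf.variable_scope('concat_state'):
    # Old behavior, now opt-in: one concatenated state Tensor of shape [1, 8].
    m = tf.zeros([1, 8])
    g, out_m = tf.nn.rnn_cell.MultiRNNCell(
        [tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=False)] * 2,
        state_is_tuple=False)(x, m)

with tf.variable_scope('tuple_state'):
    # New default: one (c, m) pair of [1, 2] Tensors per layer.
    m0 = (tf.zeros([1, 2]),) * 2
    m1 = (tf.zeros([1, 2]),) * 2
    cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.BasicLSTMCell(2)] * 2)
    g2, (out_m0, out_m1) = cell(x, (m0, m1))

sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run([g, out_m, g2])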


@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for GridRNN cells."""
from __future__ import absolute_import
@ -27,7 +26,8 @@ class GridRNNCellTest(tf.test.TestCase):
def testGrid2BasicLSTMCell(self):
with self.test_session() as sess:
with tf.variable_scope('root', initializer=tf.constant_initializer(0.2)) as root_scope:
with tf.variable_scope(
'root', initializer=tf.constant_initializer(0.2)) as root_scope:
x = tf.zeros([1, 3])
m = tf.zeros([1, 8])
cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(2)
@ -38,15 +38,18 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(s.get_shape(), (1, 8))
sess.run([tf.initialize_all_variables()])
res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
res = sess.run(
[g, s], {x: np.array([[1., 1., 1.]]),
m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 8))
self.assertAllClose(res[0], [[ 0.36617181, 0.36617181]])
self.assertAllClose(res[1], [[ 0.71053141, 0.71053141, 0.36617181, 0.36617181,
0.72320831, 0.80555487, 0.39102408, 0.42150158]])
self.assertAllClose(res[0], [[0.36617181, 0.36617181]])
self.assertAllClose(res[1], [[0.71053141, 0.71053141, 0.36617181,
0.36617181, 0.72320831, 0.80555487,
0.39102408, 0.42150158]])
# emulate a loop through the input sequence, where we call cell() multiple times
# emulate a loop through the input sequence,
# where we call cell() multiple times
root_scope.reuse_variables()
g2, s2 = cell(x, m)
self.assertEqual(g2.get_shape(), (1, 2))
@ -56,8 +59,9 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 8))
self.assertAllClose(res[0], [[0.58847463, 0.58847463]])
self.assertAllClose(res[1], [[1.40469193, 1.40469193, 0.58847463, 0.58847463,
0.97726452, 1.04626071, 0.4927212, 0.51137757]])
self.assertAllClose(res[1], [[1.40469193, 1.40469193, 0.58847463,
0.58847463, 0.97726452, 1.04626071,
0.4927212, 0.51137757]])
def testGrid2BasicLSTMCellTied(self):
with self.test_session() as sess:
@ -72,27 +76,31 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(s.get_shape(), (1, 8))
sess.run([tf.initialize_all_variables()])
res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
res = sess.run(
[g, s], {x: np.array([[1., 1., 1.]]),
m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 8))
self.assertAllClose(res[0], [[ 0.36617181, 0.36617181]])
self.assertAllClose(res[1], [[ 0.71053141, 0.71053141, 0.36617181, 0.36617181,
0.72320831, 0.80555487, 0.39102408, 0.42150158]])
self.assertAllClose(res[0], [[0.36617181, 0.36617181]])
self.assertAllClose(res[1], [[0.71053141, 0.71053141, 0.36617181,
0.36617181, 0.72320831, 0.80555487,
0.39102408, 0.42150158]])
res = sess.run([g, s], {x: np.array([[1., 1., 1.]]), m: res[1]})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 8))
self.assertAllClose(res[0], [[0.36703536, 0.36703536]])
self.assertAllClose(res[1], [[0.71200621, 0.71200621, 0.36703536, 0.36703536,
0.80941606, 0.87550586, 0.40108523, 0.42199609]])
self.assertAllClose(res[1], [[0.71200621, 0.71200621, 0.36703536,
0.36703536, 0.80941606, 0.87550586,
0.40108523, 0.42199609]])
def testGrid2BasicLSTMCellWithRelu(self):
with self.test_session() as sess:
with tf.variable_scope('root', initializer=tf.constant_initializer(0.2)):
x = tf.zeros([1, 3])
m = tf.zeros([1, 4])
cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(2, tied=False, non_recurrent_fn=tf.nn.relu)
cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(
2, tied=False, non_recurrent_fn=tf.nn.relu)
self.assertEqual(cell.state_size, 4)
g, s = cell(x, m)
@ -104,11 +112,11 @@ class GridRNNCellTest(tf.test.TestCase):
m: np.array([[0.1, 0.2, 0.3, 0.4]])})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 4))
self.assertAllClose(res[0], [[ 0.31667367, 0.31667367]])
self.assertAllClose(res[1], [[ 0.29530135, 0.37520045, 0.17044567, 0.21292259]])
self.assertAllClose(res[0], [[0.31667367, 0.31667367]])
self.assertAllClose(res[1],
[[0.29530135, 0.37520045, 0.17044567, 0.21292259]])
"""
LSTMCell
"""LSTMCell
"""
def testGrid2LSTMCell(self):
@ -124,20 +132,23 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(s.get_shape(), (1, 8))
sess.run([tf.initialize_all_variables()])
res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
res = sess.run(
[g, s], {x: np.array([[1., 1., 1.]]),
m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 8))
self.assertAllClose(res[0], [[ 0.95686918, 0.95686918]])
self.assertAllClose(res[1], [[ 2.41515064, 2.41515064, 0.95686918, 0.95686918,
1.38917875, 1.49043763, 0.83884692, 0.86036491]])
self.assertAllClose(res[0], [[0.95686918, 0.95686918]])
self.assertAllClose(res[1], [[2.41515064, 2.41515064, 0.95686918,
0.95686918, 1.38917875, 1.49043763,
0.83884692, 0.86036491]])
def testGrid2LSTMCellTied(self):
with self.test_session() as sess:
with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
x = tf.zeros([1, 3])
m = tf.zeros([1, 8])
cell = tf.contrib.grid_rnn.Grid2LSTMCell(2, tied=True, use_peepholes=True)
cell = tf.contrib.grid_rnn.Grid2LSTMCell(
2, tied=True, use_peepholes=True)
self.assertEqual(cell.state_size, 8)
g, s = cell(x, m)
@ -145,20 +156,23 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(s.get_shape(), (1, 8))
sess.run([tf.initialize_all_variables()])
res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
res = sess.run(
[g, s], {x: np.array([[1., 1., 1.]]),
m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 8))
self.assertAllClose(res[0], [[ 0.95686918, 0.95686918]])
self.assertAllClose(res[1], [[ 2.41515064, 2.41515064, 0.95686918, 0.95686918,
1.38917875, 1.49043763, 0.83884692, 0.86036491]])
self.assertAllClose(res[0], [[0.95686918, 0.95686918]])
self.assertAllClose(res[1], [[2.41515064, 2.41515064, 0.95686918,
0.95686918, 1.38917875, 1.49043763,
0.83884692, 0.86036491]])
def testGrid2LSTMCellWithRelu(self):
with self.test_session() as sess:
with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
x = tf.zeros([1, 3])
m = tf.zeros([1, 4])
cell = tf.contrib.grid_rnn.Grid2LSTMCell(2, use_peepholes=True, non_recurrent_fn=tf.nn.relu)
cell = tf.contrib.grid_rnn.Grid2LSTMCell(
2, use_peepholes=True, non_recurrent_fn=tf.nn.relu)
self.assertEqual(cell.state_size, 4)
g, s = cell(x, m)
@ -170,11 +184,11 @@ class GridRNNCellTest(tf.test.TestCase):
m: np.array([[0.1, 0.2, 0.3, 0.4]])})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 4))
self.assertAllClose(res[0], [[ 2.1831727, 2.1831727]])
self.assertAllClose(res[1], [[ 0.92270052, 1.02325559, 0.66159075, 0.70475441]])
self.assertAllClose(res[0], [[2.1831727, 2.1831727]])
self.assertAllClose(res[1],
[[0.92270052, 1.02325559, 0.66159075, 0.70475441]])
"""
RNNCell
"""RNNCell
"""
def testGrid2BasicRNNCell(self):
@ -190,14 +204,16 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(s.get_shape(), (2, 4))
sess.run([tf.initialize_all_variables()])
res = sess.run([g, s], {x: np.array([[1., 1.], [2., 2.]]),
m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])})
res = sess.run(
[g, s], {x: np.array([[1., 1.], [2., 2.]]),
m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])})
self.assertEqual(res[0].shape, (2, 2))
self.assertEqual(res[1].shape, (2, 4))
self.assertAllClose(res[0], [[0.94685763, 0.94685763],
[0.99480951, 0.99480951]])
self.assertAllClose(res[1], [[0.94685763, 0.94685763, 0.80049908, 0.80049908],
[0.99480951, 0.99480951, 0.97574311, 0.97574311]])
self.assertAllClose(res[1],
[[0.94685763, 0.94685763, 0.80049908, 0.80049908],
[0.99480951, 0.99480951, 0.97574311, 0.97574311]])
def testGrid2BasicRNNCellTied(self):
with self.test_session() as sess:
@ -212,21 +228,24 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(s.get_shape(), (2, 4))
sess.run([tf.initialize_all_variables()])
res = sess.run([g, s], {x: np.array([[1., 1.], [2., 2.]]),
m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])})
res = sess.run(
[g, s], {x: np.array([[1., 1.], [2., 2.]]),
m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])})
self.assertEqual(res[0].shape, (2, 2))
self.assertEqual(res[1].shape, (2, 4))
self.assertAllClose(res[0], [[0.94685763, 0.94685763],
[0.99480951, 0.99480951]])
self.assertAllClose(res[1], [[0.94685763, 0.94685763, 0.80049908, 0.80049908],
[0.99480951, 0.99480951, 0.97574311, 0.97574311]])
self.assertAllClose(res[1],
[[0.94685763, 0.94685763, 0.80049908, 0.80049908],
[0.99480951, 0.99480951, 0.97574311, 0.97574311]])
def testGrid2BasicRNNCellWithRelu(self):
with self.test_session() as sess:
with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
x = tf.zeros([1, 2])
m = tf.zeros([1, 2])
cell = tf.contrib.grid_rnn.Grid2BasicRNNCell(2, non_recurrent_fn=tf.nn.relu)
cell = tf.contrib.grid_rnn.Grid2BasicRNNCell(
2, non_recurrent_fn=tf.nn.relu)
self.assertEqual(cell.state_size, 2)
g, s = cell(x, m)
@ -234,19 +253,20 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(s.get_shape(), (1, 2))
sess.run([tf.initialize_all_variables()])
res = sess.run([g, s], {x: np.array([[1., 1.]]), m: np.array([[0.1, 0.1]])})
res = sess.run([g, s], {x: np.array([[1., 1.]]),
m: np.array([[0.1, 0.1]])})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 2))
self.assertAllClose(res[0], [[1.80049896, 1.80049896]])
self.assertAllClose(res[1], [[0.80049896, 0.80049896]])
"""
1-LSTM
"""1-LSTM
"""
def testGrid1LSTMCell(self):
with self.test_session() as sess:
with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)) as root_scope:
with tf.variable_scope(
'root', initializer=tf.constant_initializer(0.5)) as root_scope:
x = tf.zeros([1, 3])
m = tf.zeros([1, 4])
cell = tf.contrib.grid_rnn.Grid1LSTMCell(2, use_peepholes=True)
@ -262,7 +282,8 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 4))
self.assertAllClose(res[0], [[0.91287315, 0.91287315]])
self.assertAllClose(res[1], [[2.26285243, 2.26285243, 0.91287315, 0.91287315]])
self.assertAllClose(res[1],
[[2.26285243, 2.26285243, 0.91287315, 0.91287315]])
root_scope.reuse_variables()
@ -276,7 +297,8 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 4))
self.assertAllClose(res[0], [[0.9032144, 0.9032144]])
self.assertAllClose(res[1], [[2.79966092, 2.79966092, 0.9032144, 0.9032144]])
self.assertAllClose(res[1],
[[2.79966092, 2.79966092, 0.9032144, 0.9032144]])
g3, s3 = cell(x2, m)
self.assertEqual(g3.get_shape(), (1, 2))
@ -287,11 +309,12 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 4))
self.assertAllClose(res[0], [[0.92727238, 0.92727238]])
self.assertAllClose(res[1], [[3.3529923, 3.3529923, 0.92727238, 0.92727238]])
self.assertAllClose(res[1],
[[3.3529923, 3.3529923, 0.92727238, 0.92727238]])
"""3-LSTM
"""
3-LSTM
"""
def testGrid3LSTMCell(self):
with self.test_session() as sess:
with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
@ -306,18 +329,20 @@ class GridRNNCellTest(tf.test.TestCase):
sess.run([tf.initialize_all_variables()])
res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, -0.1, -0.2, -0.3, -0.4]])})
m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7,
0.8, -0.1, -0.2, -0.3, -0.4]])})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 12))
self.assertAllClose(res[0], [[0.96892911, 0.96892911]])
self.assertAllClose(res[1], [[2.45227885, 2.45227885, 0.96892911, 0.96892911,
1.33592629, 1.4373529, 0.80867189, 0.83247656,
0.7317788, 0.63205892, 0.56548983, 0.50446129]])
self.assertAllClose(res[1], [[2.45227885, 2.45227885, 0.96892911,
0.96892911, 1.33592629, 1.4373529,
0.80867189, 0.83247656, 0.7317788,
0.63205892, 0.56548983, 0.50446129]])
"""Edge cases
"""
Edge cases
"""
def testGridRNNEdgeCasesLikeRelu(self):
with self.test_session() as sess:
with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
@ -325,8 +350,13 @@ class GridRNNCellTest(tf.test.TestCase):
m = tf.zeros([0, 0])
# this is equivalent to relu
cell = tf.contrib.grid_rnn.GridRNNCell(num_units=2, num_dims=1, input_dims=0, output_dims=0,
non_recurrent_dims=0, non_recurrent_fn=tf.nn.relu)
cell = tf.contrib.grid_rnn.GridRNNCell(
num_units=2,
num_dims=1,
input_dims=0,
output_dims=0,
non_recurrent_dims=0,
non_recurrent_fn=tf.nn.relu)
g, s = cell(x, m)
self.assertEqual(g.get_shape(), (3, 2))
self.assertEqual(s.get_shape(), (0, 0))
@ -340,12 +370,17 @@ class GridRNNCellTest(tf.test.TestCase):
def testGridRNNEdgeCasesNoOutput(self):
with self.test_session() as sess:
with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
x = tf.zeros([1, 2])
x = tf.zeros([1, 2])
m = tf.zeros([1, 4])
# This cell produces no output
cell = tf.contrib.grid_rnn.GridRNNCell(num_units=2, num_dims=2, input_dims=0, output_dims=None,
non_recurrent_dims=0, non_recurrent_fn=tf.nn.relu)
cell = tf.contrib.grid_rnn.GridRNNCell(
num_units=2,
num_dims=2,
input_dims=0,
output_dims=None,
non_recurrent_dims=0,
non_recurrent_fn=tf.nn.relu)
g, s = cell(x, m)
self.assertEqual(g.get_shape(), (0, 0))
self.assertEqual(s.get_shape(), (1, 4))
@ -356,8 +391,7 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(res[0].shape, (0, 0))
self.assertEqual(res[1].shape, (1, 4))
"""
Test with tf.nn.rnn
"""Test with tf.nn.rnn
"""
def testGrid2LSTMCellWithRNN(self):
@ -370,7 +404,9 @@ class GridRNNCellTest(tf.test.TestCase):
cell = tf.contrib.grid_rnn.Grid2LSTMCell(num_units=num_units)
inputs = max_length * [
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
tf.placeholder(
tf.float32, shape=(batch_size, input_size))
]
outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
@ -386,8 +422,7 @@ class GridRNNCellTest(tf.test.TestCase):
sess.run(tf.initialize_all_variables())
input_value = np.ones((batch_size, input_size))
values = sess.run(outputs + [state],
feed_dict={inputs[0]: input_value})
values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
for v in values:
self.assertTrue(np.all(np.isfinite(v)))
@ -398,10 +433,13 @@ class GridRNNCellTest(tf.test.TestCase):
num_units = 2
with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
cell = tf.contrib.grid_rnn.Grid2LSTMCell(num_units=num_units, non_recurrent_fn=tf.nn.relu)
cell = tf.contrib.grid_rnn.Grid2LSTMCell(
num_units=num_units, non_recurrent_fn=tf.nn.relu)
inputs = max_length * [
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
tf.placeholder(
tf.float32, shape=(batch_size, input_size))
]
outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
@ -417,8 +455,7 @@ class GridRNNCellTest(tf.test.TestCase):
sess.run(tf.initialize_all_variables())
input_value = np.ones((batch_size, input_size))
values = sess.run(outputs + [state],
feed_dict={inputs[0]: input_value})
values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
for v in values:
self.assertTrue(np.all(np.isfinite(v)))
@ -429,10 +466,13 @@ class GridRNNCellTest(tf.test.TestCase):
num_units = 2
with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
cell = tf.contrib.grid_rnn.Grid3LSTMCell(num_units=num_units, non_recurrent_fn=tf.nn.relu)
cell = tf.contrib.grid_rnn.Grid3LSTMCell(
num_units=num_units, non_recurrent_fn=tf.nn.relu)
inputs = max_length * [
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
tf.placeholder(
tf.float32, shape=(batch_size, input_size))
]
outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
@ -448,12 +488,10 @@ class GridRNNCellTest(tf.test.TestCase):
sess.run(tf.initialize_all_variables())
input_value = np.ones((batch_size, input_size))
values = sess.run(outputs + [state],
feed_dict={inputs[0]: input_value})
values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
for v in values:
self.assertTrue(np.all(np.isfinite(v)))
def testGrid1LSTMCellWithRNN(self):
batch_size = 3
input_size = 5
@ -464,8 +502,8 @@ class GridRNNCellTest(tf.test.TestCase):
cell = tf.contrib.grid_rnn.Grid1LSTMCell(num_units=num_units)
# for 1-LSTM, we only feed the first step
inputs = [tf.placeholder(tf.float32, shape=(batch_size, input_size))] \
+ (max_length - 1) * [tf.zeros([batch_size, input_size])]
inputs = ([tf.placeholder(tf.float32, shape=(batch_size, input_size))]
+ (max_length - 1) * [tf.zeros([batch_size, input_size])])
outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
@ -480,10 +518,10 @@ class GridRNNCellTest(tf.test.TestCase):
sess.run(tf.initialize_all_variables())
input_value = np.ones((batch_size, input_size))
values = sess.run(outputs + [state],
feed_dict={inputs[0]: input_value})
values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
for v in values:
self.assertTrue(np.all(np.isfinite(v)))
if __name__ == "__main__":
if __name__ == '__main__':
tf.test.main()
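A note on the constants above (hedged, not part of the commit): the state widths these tests assert (8 for the 2-D LSTM grids, 12 for the 3-D grid, 4 when the first dimension is non-recurrent) follow from GridRNNCell keeping a single concatenated state, with each recurrent dimension contributing the wrapped cell's full non-tuple state (c plus m, i.e. 2 * num_units for an LSTM). A small sketch of that arithmetic using only the constructors exercised above:

import tensorflow as tf

num_units = 2
cell_2d = tf.contrib.grid_rnn.Grid2BasicLSTMCell(num_units)
cell_3d = tf.contrib.grid_rnn.Grid3LSTMCell(num_units)
cell_relu = tf.contrib.grid_rnn.Grid2BasicLSTMCell(
    num_units, tied=False, non_recurrent_fn=tf.nn.relu)

# Two recurrent dimensions, each holding (c, m) of the inner LSTM.
assert cell_2d.state_size == 2 * (2 * num_units)    # 8, as asserted above
# Three recurrent dimensions.
assert cell_3d.state_size == 3 * (2 * num_units)    # 12, as asserted above
# The relu (non-recurrent) first dimension keeps no state.
assert cell_relu.state_size == 1 * (2 * num_units)  # 4, as asserted above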


@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# TODO(b/28879898) Fix all lint issues and clean the code.
"""Module for constructing GridRNN cells"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@ -35,46 +35,67 @@ class GridRNNCell(rnn_cell.RNNCell):
http://arxiv.org/pdf/1507.01526v3.pdf
This is the generic implementation of GridRNN. Users can specify arbitrary number of dimensions,
This is the generic implementation of GridRNN. Users can specify arbitrary
number of dimensions,
set some of them to be priority (section 3.2), non-recurrent (section 3.3)
and input/output dimensions (section 3.4).
Weight sharing can also be specified using the `tied` parameter.
Type of recurrent units can be specified via `cell_fn`.
"""
def __init__(self, num_units, num_dims=1, input_dims=None, output_dims=None, priority_dims=None,
non_recurrent_dims=None, tied=False, cell_fn=None, non_recurrent_fn=None):
def __init__(self,
num_units,
num_dims=1,
input_dims=None,
output_dims=None,
priority_dims=None,
non_recurrent_dims=None,
tied=False,
cell_fn=None,
non_recurrent_fn=None):
"""Initialize the parameters of a Grid RNN cell
Args:
num_units: int, The number of units in all dimensions of this GridRNN cell
num_dims: int, Number of dimensions of this grid.
input_dims: int or list, List of dimensions which will receive input data.
output_dims: int or list, List of dimensions from which the output will be recorded.
priority_dims: int or list, List of dimensions to be considered as priority dimensions.
output_dims: int or list, List of dimensions from which the output will be
recorded.
priority_dims: int or list, List of dimensions to be considered as
priority dimensions.
If None, no dimension is prioritized.
non_recurrent_dims: int or list, List of dimensions that are not recurrent.
The transfer function for non-recurrent dimensions is specified via `non_recurrent_fn`,
non_recurrent_dims: int or list, List of dimensions that are not
recurrent.
The transfer function for non-recurrent dimensions is specified
via `non_recurrent_fn`,
which is default to be `tensorflow.nn.relu`.
tied: bool, Whether to share the weights among the dimensions of this GridRNN cell.
If there are non-recurrent dimensions in the grid, weights are shared between each
tied: bool, Whether to share the weights among the dimensions of this
GridRNN cell.
If there are non-recurrent dimensions in the grid, weights are
shared between each
group of recurrent and non-recurrent dimensions.
cell_fn: function, a function which returns the recurrent cell object. Has to be in the following signature:
cell_fn: function, a function which returns the recurrent cell object. Has
to be in the following signature:
def cell_func(num_units, input_size):
# ...
and returns an object of type `RNNCell`. If None, LSTMCell with default parameters will be used.
non_recurrent_fn: a tensorflow Op that will be the transfer function of the non-recurrent dimensions
and returns an object of type `RNNCell`. If None, LSTMCell with
default parameters will be used.
non_recurrent_fn: a tensorflow Op that will be the transfer function of
the non-recurrent dimensions
"""
if num_dims < 1:
raise ValueError('dims must be >= 1: {}'.format(num_dims))
self._config = _parse_rnn_config(num_dims, input_dims, output_dims, priority_dims,
non_recurrent_dims, non_recurrent_fn or nn.relu, tied, num_units)
self._config = _parse_rnn_config(num_dims, input_dims, output_dims,
priority_dims, non_recurrent_dims,
non_recurrent_fn or nn.relu, tied,
num_units)
cell_input_size = (self._config.num_dims - 1) * num_units
if cell_fn is None:
self._cell = rnn_cell.LSTMCell(num_units=num_units, input_size=cell_input_size)
self._cell = rnn_cell.LSTMCell(
num_units=num_units, input_size=cell_input_size, state_is_tuple=False)
else:
self._cell = cell_fn(num_units, cell_input_size)
if not isinstance(self._cell, rnn_cell.RNNCell):
@ -100,7 +121,8 @@ class GridRNNCell(rnn_cell.RNNCell):
Args:
inputs: input Tensor, 2D, batch x input_size. Or None
state: state Tensor, 2D, batch x state_size. Note that state_size = cell_state_size * recurrent_dims
state: state Tensor, 2D, batch x state_size. Note that state_size =
cell_state_size * recurrent_dims
scope: VariableScope for the created subgraph; defaults to "GridRNNCell".
Returns:
@ -112,24 +134,32 @@ class GridRNNCell(rnn_cell.RNNCell):
"""
state_sz = state.get_shape().as_list()[1]
if self.state_size != state_sz:
raise ValueError('Actual state size not same as specified: {} vs {}.'.format(state_sz, self.state_size))
raise ValueError(
'Actual state size not same as specified: {} vs {}.'.format(
state_sz, self.state_size))
conf = self._config
dtype = inputs.dtype if inputs is not None else state.dtype
# c_prev is `m`, and m_prev is `h` in the paper. Keep c and m here for consistency with the codebase
# c_prev is `m`, and m_prev is `h` in the paper.
# Keep c and m here for consistency with the codebase
c_prev = [None] * self._config.num_dims
m_prev = [None] * self._config.num_dims
cell_output_size = self._cell.state_size - conf.num_units
# for LSTM : state = memory cell + output, hence cell_output_size > 0
# for GRU/RNN: state = output (whose size is equal to _num_units), hence cell_output_size = 0
for recurrent_dim, start_idx in zip(self._config.recurrents, range(0, self.state_size, self._cell.state_size)):
# for GRU/RNN: state = output (whose size is equal to _num_units),
# hence cell_output_size = 0
for recurrent_dim, start_idx in zip(self._config.recurrents, range(
0, self.state_size, self._cell.state_size)):
if cell_output_size > 0:
c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], [-1, conf.num_units])
m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx + conf.num_units], [-1, cell_output_size])
c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
[-1, conf.num_units])
m_prev[recurrent_dim] = array_ops.slice(
state, [0, start_idx + conf.num_units], [-1, cell_output_size])
else:
m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], [-1, conf.num_units])
m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
[-1, conf.num_units])
new_output = [None] * conf.num_dims
new_state = [None] * conf.num_dims
@ -137,150 +167,212 @@ class GridRNNCell(rnn_cell.RNNCell):
with vs.variable_scope(scope or type(self).__name__): # GridRNNCell
# project input
if inputs is not None and sum(inputs.get_shape().as_list()) > 0 and len(conf.inputs) > 0:
if inputs is not None and sum(inputs.get_shape().as_list()) > 0 and len(
conf.inputs) > 0:
input_splits = array_ops.split(1, len(conf.inputs), inputs)
input_sz = input_splits[0].get_shape().as_list()[1]
for i, j in enumerate(conf.inputs):
input_project_m = vs.get_variable('project_m_{}'.format(j), [input_sz, conf.num_units], dtype=dtype)
input_project_m = vs.get_variable(
'project_m_{}'.format(j), [input_sz, conf.num_units], dtype=dtype)
m_prev[j] = math_ops.matmul(input_splits[i], input_project_m)
if cell_output_size > 0:
input_project_c = vs.get_variable('project_c_{}'.format(j), [input_sz, conf.num_units], dtype=dtype)
input_project_c = vs.get_variable(
'project_c_{}'.format(j), [input_sz, conf.num_units],
dtype=dtype)
c_prev[j] = math_ops.matmul(input_splits[i], input_project_c)
_propagate(conf.non_priority, conf, self._cell, c_prev, m_prev, new_output, new_state, True)
_propagate(conf.priority, conf, self._cell, c_prev, m_prev, new_output, new_state, False)
_propagate(conf.non_priority, conf, self._cell, c_prev, m_prev,
new_output, new_state, True)
_propagate(conf.priority, conf, self._cell, c_prev, m_prev, new_output,
new_state, False)
output_tensors = [new_output[i] for i in self._config.outputs]
output = array_ops.zeros([0, 0], dtype) if len(output_tensors) == 0 else array_ops.concat(1,
output_tensors)
output = array_ops.zeros(
[0, 0], dtype) if len(output_tensors) == 0 else array_ops.concat(
1, output_tensors)
state_tensors = [new_state[i] for i in self._config.recurrents]
states = array_ops.zeros([0, 0], dtype) if len(state_tensors) == 0 else array_ops.concat(1, state_tensors)
states = array_ops.zeros(
[0, 0], dtype) if len(state_tensors) == 0 else array_ops.concat(
1, state_tensors)
return output, states
"""Specialized cells, for convenience
"""
Specialized cells, for convenience
"""
class Grid1BasicRNNCell(GridRNNCell):
"""1D BasicRNN cell"""
def __init__(self, num_units):
super(Grid1BasicRNNCell, self).__init__(num_units=num_units, num_dims=1,
input_dims=0, output_dims=0, priority_dims=0, tied=False,
cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i))
super(Grid1BasicRNNCell, self).__init__(
num_units=num_units, num_dims=1,
input_dims=0, output_dims=0, priority_dims=0, tied=False,
cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i))
class Grid2BasicRNNCell(GridRNNCell):
"""2D BasicRNN cell
This creates a 2D cell which receives input and gives output in the first dimension.
The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified.
This creates a 2D cell which receives input and gives output in the first
dimension.
The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
specified.
"""
def __init__(self, num_units, tied=False, non_recurrent_fn=None):
super(Grid2BasicRNNCell, self).__init__(num_units=num_units, num_dims=2,
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i),
non_recurrent_fn=non_recurrent_fn)
super(Grid2BasicRNNCell, self).__init__(
num_units=num_units, num_dims=2,
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i),
non_recurrent_fn=non_recurrent_fn)
class Grid1BasicLSTMCell(GridRNNCell):
"""1D BasicLSTM cell"""
def __init__(self, num_units, forget_bias=1):
super(Grid1BasicLSTMCell, self).__init__(num_units=num_units, num_dims=1,
input_dims=0, output_dims=0, priority_dims=0, tied=False,
cell_fn=lambda n, i: rnn_cell.BasicLSTMCell(num_units=n,
forget_bias=forget_bias, input_size=i))
super(Grid1BasicLSTMCell, self).__init__(
num_units=num_units, num_dims=1,
input_dims=0, output_dims=0, priority_dims=0, tied=False,
cell_fn=lambda n, i: rnn_cell.BasicLSTMCell(
num_units=n,
forget_bias=forget_bias, input_size=i,
state_is_tuple=False))
class Grid2BasicLSTMCell(GridRNNCell):
"""2D BasicLSTM cell
This creates a 2D cell which receives input and gives output in the first dimension.
The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified.
This creates a 2D cell which receives input and gives output in the first
dimension.
The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
specified.
"""
def __init__(self, num_units, tied=False, non_recurrent_fn=None, forget_bias=1):
super(Grid2BasicLSTMCell, self).__init__(num_units=num_units, num_dims=2,
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
cell_fn=lambda n, i: rnn_cell.BasicLSTMCell(
num_units=n, forget_bias=forget_bias, input_size=i),
non_recurrent_fn=non_recurrent_fn)
def __init__(self,
num_units,
tied=False,
non_recurrent_fn=None,
forget_bias=1):
super(Grid2BasicLSTMCell, self).__init__(
num_units=num_units, num_dims=2,
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
cell_fn=lambda n, i: rnn_cell.BasicLSTMCell(
num_units=n, forget_bias=forget_bias, input_size=i,
state_is_tuple=False),
non_recurrent_fn=non_recurrent_fn)
class Grid1LSTMCell(GridRNNCell):
"""1D LSTM cell
This is different from Grid1BasicLSTMCell because it gives options to specify the forget bias and enabling peepholes
This is different from Grid1BasicLSTMCell because it gives options to
specify the forget bias and enabling peepholes
"""
def __init__(self, num_units, use_peepholes=False, forget_bias=1.0):
super(Grid1LSTMCell, self).__init__(num_units=num_units, num_dims=1,
input_dims=0, output_dims=0, priority_dims=0,
cell_fn=lambda n, i: rnn_cell.LSTMCell(
num_units=n, input_size=i, use_peepholes=use_peepholes,
forget_bias=forget_bias))
super(Grid1LSTMCell, self).__init__(
num_units=num_units, num_dims=1,
input_dims=0, output_dims=0, priority_dims=0,
cell_fn=lambda n, i: rnn_cell.LSTMCell(
num_units=n, input_size=i, use_peepholes=use_peepholes,
forget_bias=forget_bias, state_is_tuple=False))
class Grid2LSTMCell(GridRNNCell):
"""2D LSTM cell
This creates a 2D cell which receives input and gives output in the first dimension.
The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified.
This creates a 2D cell which receives input and gives output in the first
dimension.
The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
specified.
"""
def __init__(self, num_units, tied=False, non_recurrent_fn=None,
use_peepholes=False, forget_bias=1.0):
super(Grid2LSTMCell, self).__init__(num_units=num_units, num_dims=2,
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
cell_fn=lambda n, i: rnn_cell.LSTMCell(
num_units=n, input_size=i, forget_bias=forget_bias,
use_peepholes=use_peepholes),
non_recurrent_fn=non_recurrent_fn)
def __init__(self,
num_units,
tied=False,
non_recurrent_fn=None,
use_peepholes=False,
forget_bias=1.0):
super(Grid2LSTMCell, self).__init__(
num_units=num_units, num_dims=2,
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
cell_fn=lambda n, i: rnn_cell.LSTMCell(
num_units=n, input_size=i, forget_bias=forget_bias,
use_peepholes=use_peepholes, state_is_tuple=False),
non_recurrent_fn=non_recurrent_fn)
class Grid3LSTMCell(GridRNNCell):
"""3D BasicLSTM cell
This creates a 2D cell which receives input and gives output in the first dimension.
The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified.
This creates a 2D cell which receives input and gives output in the first
dimension.
The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
specified.
The second and third dimensions are LSTM.
"""
def __init__(self, num_units, tied=False, non_recurrent_fn=None,
use_peepholes=False, forget_bias=1.0):
super(Grid3LSTMCell, self).__init__(num_units=num_units, num_dims=3,
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
cell_fn=lambda n, i: rnn_cell.LSTMCell(
num_units=n, input_size=i, forget_bias=forget_bias,
use_peepholes=use_peepholes),
non_recurrent_fn=non_recurrent_fn)
def __init__(self,
num_units,
tied=False,
non_recurrent_fn=None,
use_peepholes=False,
forget_bias=1.0):
super(Grid3LSTMCell, self).__init__(
num_units=num_units, num_dims=3,
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
cell_fn=lambda n, i: rnn_cell.LSTMCell(
num_units=n, input_size=i, forget_bias=forget_bias,
use_peepholes=use_peepholes, state_is_tuple=False),
non_recurrent_fn=non_recurrent_fn)
class Grid2GRUCell(GridRNNCell):
"""2D LSTM cell
This creates a 2D cell which receives input and gives output in the first dimension.
The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified.
This creates a 2D cell which receives input and gives output in the first
dimension.
The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
specified.
"""
def __init__(self, num_units, tied=False, non_recurrent_fn=None):
super(Grid2GRUCell, self).__init__(num_units=num_units, num_dims=2,
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
cell_fn=lambda n, i: rnn_cell.GRUCell(num_units=n, input_size=i),
non_recurrent_fn=non_recurrent_fn)
super(Grid2GRUCell, self).__init__(
num_units=num_units, num_dims=2,
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
cell_fn=lambda n, i: rnn_cell.GRUCell(num_units=n, input_size=i),
non_recurrent_fn=non_recurrent_fn)
"""
Helpers
"""Helpers
"""
_GridRNNDimension = namedtuple('_GridRNNDimension', ['idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn'])
_GridRNNDimension = namedtuple(
'_GridRNNDimension',
['idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn'])
_GridRNNConfig = namedtuple('_GridRNNConfig', ['num_dims', 'dims',
'inputs', 'outputs', 'recurrents',
'priority', 'non_priority', 'tied', 'num_units'])
_GridRNNConfig = namedtuple('_GridRNNConfig',
['num_dims', 'dims', 'inputs', 'outputs',
'recurrents', 'priority', 'non_priority', 'tied',
'num_units'])
def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims, ls_non_recurrent_dims,
non_recurrent_fn, tied, num_units):
def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims,
ls_non_recurrent_dims, non_recurrent_fn, tied, num_units):
def check_dim_list(ls):
if ls is None:
ls = []
@ -288,7 +380,8 @@ def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims,
ls = [ls]
ls = sorted(set(ls))
if any(_ < 0 or _ >= num_dims for _ in ls):
raise ValueError('Invalid dims: {}. Must be in [0, {})'.format(ls, num_dims))
raise ValueError('Invalid dims: {}. Must be in [0, {})'.format(ls,
num_dims))
return ls
input_dims = check_dim_list(ls_input_dims)
@ -298,42 +391,58 @@ def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims,
rnn_dims = []
for i in range(num_dims):
rnn_dims.append(_GridRNNDimension(idx=i, is_input=(i in input_dims), is_output=(i in output_dims),
is_priority=(i in priority_dims),
non_recurrent_fn=non_recurrent_fn if i in non_recurrent_dims else None))
return _GridRNNConfig(num_dims=num_dims, dims=rnn_dims, inputs=input_dims, outputs=output_dims,
recurrents=[x for x in range(num_dims) if x not in non_recurrent_dims],
priority=priority_dims,
non_priority=[x for x in range(num_dims) if x not in priority_dims],
tied=tied, num_units=num_units)
rnn_dims.append(
_GridRNNDimension(
idx=i,
is_input=(i in input_dims),
is_output=(i in output_dims),
is_priority=(i in priority_dims),
non_recurrent_fn=non_recurrent_fn if i in non_recurrent_dims else
None))
return _GridRNNConfig(
num_dims=num_dims,
dims=rnn_dims,
inputs=input_dims,
outputs=output_dims,
recurrents=[x for x in range(num_dims) if x not in non_recurrent_dims],
priority=priority_dims,
non_priority=[x for x in range(num_dims) if x not in priority_dims],
tied=tied,
num_units=num_units)
def _propagate(dim_indices, conf, cell, c_prev, m_prev, new_output, new_state, first_call):
"""
Propagates through all the cells in dim_indices dimensions.
def _propagate(dim_indices, conf, cell, c_prev, m_prev, new_output, new_state,
first_call):
"""Propagates through all the cells in dim_indices dimensions.
"""
if len(dim_indices) == 0:
return
# Because of the way RNNCells are implemented, we take the last dimension (H_{N-1}) out
# and feed it as the state of the RNN cell (in `last_dim_output`)
# Because of the way RNNCells are implemented, we take the last dimension
# (H_{N-1}) out and feed it as the state of the RNN cell
# (in `last_dim_output`).
# The input of the cell (H_0 to H_{N-2}) are concatenated into `cell_inputs`
if conf.num_dims > 1:
ls_cell_inputs = [None] * (conf.num_dims - 1)
for d in conf.dims[:-1]:
ls_cell_inputs[d.idx] = new_output[d.idx] if new_output[d.idx] is not None else m_prev[d.idx]
ls_cell_inputs[d.idx] = new_output[d.idx] if new_output[
d.idx] is not None else m_prev[d.idx]
cell_inputs = array_ops.concat(1, ls_cell_inputs)
else:
cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0], m_prev[0].dtype)
cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0],
m_prev[0].dtype)
last_dim_output = new_output[-1] if new_output[-1] is not None else m_prev[-1]
for i in dim_indices:
d = conf.dims[i]
if d.non_recurrent_fn:
linear_args = array_ops.concat(1, [cell_inputs, last_dim_output]) if conf.num_dims > 1 else last_dim_output
with vs.variable_scope('non_recurrent' if conf.tied else 'non_recurrent/cell_{}'.format(i)):
if conf.tied and not(first_call and i == dim_indices[0]):
linear_args = array_ops.concat(
1, [cell_inputs, last_dim_output
]) if conf.num_dims > 1 else last_dim_output
with vs.variable_scope('non_recurrent' if conf.tied else
'non_recurrent/cell_{}'.format(i)):
if conf.tied and not (first_call and i == dim_indices[0]):
vs.get_variable_scope().reuse_variables()
new_output[d.idx] = layers.legacy_fully_connected(
linear_args,
@ -348,7 +457,8 @@ def _propagate(dim_indices, conf, cell, c_prev, m_prev, new_output, new_state, f
# for GRU/RNN, the state is just the previous output
cell_state = last_dim_output
with vs.variable_scope('recurrent' if conf.tied else 'recurrent/cell_{}'.format(i)):
with vs.variable_scope('recurrent' if conf.tied else
'recurrent/cell_{}'.format(i)):
if conf.tied and not (first_call and i == dim_indices[0]):
vs.get_variable_scope().reuse_variables()
new_output[d.idx], new_state[d.idx] = cell(cell_inputs, cell_state)
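The cell_fn hook documented in GridRNNCell.__init__ earlier in this file takes (num_units, input_size) and must return an RNNCell; the specialized classes (Grid2BasicLSTMCell, Grid2GRUCell, and so on) are thin wrappers that pre-fill it. A hedged sketch, not part of the commit, of driving GridRNNCell directly with a custom cell_fn, mirroring what Grid2GRUCell does:

import tensorflow as tf

# GRU keeps no separate memory cell, so cell_output_size is 0 inside
# GridRNNCell.__call__ and the grid state is just num_dims * num_units wide.
grid_gru = tf.contrib.grid_rnn.GridRNNCell(
    num_units=2,
    num_dims=2,
    input_dims=0,
    output_dims=0,
    priority_dims=0,
    cell_fn=lambda n, i: tf.nn.rnn_cell.GRUCell(num_units=n, input_size=i))

x = tf.zeros([1, 2])
m = tf.zeros([1, grid_gru.state_size])  # 2 dims * 2 units = 4 columns
g, s = grid_gru(x, m)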


@ -89,7 +89,8 @@ class RNNCellTest(tf.test.TestCase):
x = tf.zeros([1, 2])
m = tf.zeros([1, 8])
g, out_m = tf.nn.rnn_cell.MultiRNNCell(
[tf.nn.rnn_cell.BasicLSTMCell(2)] * 2)(x, m)
[tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=False)] * 2,
state_is_tuple=False)(x, m)
sess.run([tf.initialize_all_variables()])
res = sess.run([g, out_m], {x.name: np.array([[1., 1.]]),
m.name: 0.1 * np.ones([1, 8])})
@ -104,7 +105,7 @@ class RNNCellTest(tf.test.TestCase):
with tf.variable_scope("other", initializer=tf.constant_initializer(0.5)):
x = tf.zeros([1, 3]) # Test BasicLSTMCell with input_size != num_units.
m = tf.zeros([1, 4])
g, out_m = tf.nn.rnn_cell.BasicLSTMCell(2)(x, m)
g, out_m = tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=False)(x, m)
sess.run([tf.initialize_all_variables()])
res = sess.run([g, out_m], {x.name: np.array([[1., 1., 1.]]),
m.name: 0.1 * np.ones([1, 4])})
@ -117,7 +118,7 @@ class RNNCellTest(tf.test.TestCase):
m0 = (tf.zeros([1, 2]),) * 2
m1 = (tf.zeros([1, 2]),) * 2
cell = tf.nn.rnn_cell.MultiRNNCell(
[tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=True)] * 2,
[tf.nn.rnn_cell.BasicLSTMCell(2)] * 2,
state_is_tuple=True)
self.assertTrue(isinstance(cell.state_size, tuple))
self.assertTrue(isinstance(cell.state_size[0],
@ -153,7 +154,8 @@ class RNNCellTest(tf.test.TestCase):
m0 = tf.zeros([1, 4])
m1 = tf.zeros([1, 4])
cell = tf.nn.rnn_cell.MultiRNNCell(
[tf.nn.rnn_cell.BasicLSTMCell(2)] * 2, state_is_tuple=True)
[tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=False)] * 2,
state_is_tuple=True)
g, (out_m0, out_m1) = cell(x, (m0, m1))
sess.run([tf.initialize_all_variables()])
res = sess.run([g, out_m0, out_m1],
@ -183,7 +185,8 @@ class RNNCellTest(tf.test.TestCase):
x = tf.zeros([batch_size, input_size])
m = tf.zeros([batch_size, state_size])
cell = tf.nn.rnn_cell.LSTMCell(
num_units=num_units, num_proj=num_proj, forget_bias=1.0)
num_units=num_units, num_proj=num_proj, forget_bias=1.0,
state_is_tuple=False)
output, state = cell(x, m)
sess.run([tf.initialize_all_variables()])
res = sess.run([output, state],
@ -286,7 +289,7 @@ class RNNCellTest(tf.test.TestCase):
x = tf.zeros([1, 2])
m = tf.zeros([1, 4])
_, ml = tf.nn.rnn_cell.MultiRNNCell(
[tf.nn.rnn_cell.GRUCell(2)] * 2)(x, m)
[tf.nn.rnn_cell.GRUCell(2)] * 2, state_is_tuple=False)(x, m)
sess.run([tf.initialize_all_variables()])
res = sess.run(ml, {x.name: np.array([[1., 1.]]),
m.name: np.array([[0.1, 0.1, 0.1, 0.1]])})
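The GRU case just above follows the same pattern for cells whose state is only the output: with state_is_tuple=False both layers share a single [1, 4] Tensor, while under the new default each layer gets its own [1, 2] state. A hedged sketch (not part of the commit) of the two call forms:

import tensorflow as tf

x = tf.zeros([1, 2])

with tf.variable_scope('concat'):
    # Old behavior, now opt-in.
    m = tf.zeros([1, 4])
    _, ml = tf.nn.rnn_cell.MultiRNNCell(
        [tf.nn.rnn_cell.GRUCell(2)] * 2, state_is_tuple=False)(x, m)

with tf.variable_scope('tuple'):
    # New default: an n-tuple of per-layer states, n = len(cells).
    m = (tf.zeros([1, 2]), tf.zeros([1, 2]))
    _, (m0, m1) = tf.nn.rnn_cell.MultiRNNCell(
        [tf.nn.rnn_cell.GRUCell(2)] * 2)(x, m)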


@ -363,7 +363,8 @@ class LSTMTest(tf.test.TestCase):
max_length = 8
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
cell = tf.nn.rnn_cell.LSTMCell(num_units, initializer=initializer)
cell = tf.nn.rnn_cell.LSTMCell(num_units, initializer=initializer,
state_is_tuple=False)
inputs = max_length * [
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
@ -383,7 +384,8 @@ class LSTMTest(tf.test.TestCase):
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
cell = tf.nn.rnn_cell.LSTMCell(
num_units, use_peepholes=True, cell_clip=0.0, initializer=initializer)
num_units, use_peepholes=True, cell_clip=0.0, initializer=initializer,
state_is_tuple=False)
inputs = max_length * [
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
@ -408,7 +410,8 @@ class LSTMTest(tf.test.TestCase):
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
state_saver = TestStateSaver(batch_size, 2 * num_units)
cell = tf.nn.rnn_cell.LSTMCell(
num_units, use_peepholes=False, initializer=initializer)
num_units, use_peepholes=False, initializer=initializer,
state_is_tuple=False)
inputs = max_length * [
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
with tf.variable_scope("share_scope"):
@ -525,7 +528,8 @@ class LSTMTest(tf.test.TestCase):
tf.placeholder(tf.float32, shape=(None, input_size))]
cell = tf.nn.rnn_cell.LSTMCell(
num_units, use_peepholes=True,
num_proj=num_proj, initializer=initializer)
num_proj=num_proj, initializer=initializer,
state_is_tuple=False)
outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
self.assertEqual(len(outputs), len(inputs))
@ -546,7 +550,7 @@ class LSTMTest(tf.test.TestCase):
tf.placeholder(tf.float32, shape=(None, input_size))]
cell_notuple = tf.nn.rnn_cell.LSTMCell(
num_units, use_peepholes=True,
num_proj=num_proj, initializer=initializer)
num_proj=num_proj, initializer=initializer, state_is_tuple=False)
cell_tuple = tf.nn.rnn_cell.LSTMCell(
num_units, use_peepholes=True,
num_proj=num_proj, initializer=initializer, state_is_tuple=True)
@ -596,7 +600,8 @@ class LSTMTest(tf.test.TestCase):
num_proj=num_proj,
num_unit_shards=num_unit_shards,
num_proj_shards=num_proj_shards,
initializer=initializer)
initializer=initializer,
state_is_tuple=False)
outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
@ -625,7 +630,8 @@ class LSTMTest(tf.test.TestCase):
num_proj=num_proj,
num_unit_shards=num_unit_shards,
num_proj_shards=num_proj_shards,
initializer=initializer)
initializer=initializer,
state_is_tuple=False)
with self.assertRaises(ValueError):
tf.nn.rnn(cell, inputs, dtype=tf.float32)
@ -649,7 +655,8 @@ class LSTMTest(tf.test.TestCase):
num_proj=num_proj,
num_unit_shards=num_unit_shards,
num_proj_shards=num_proj_shards,
initializer=initializer)
initializer=initializer,
state_is_tuple=False)
outputs, _ = tf.nn.rnn(
cell, inputs, initial_state=cell.zero_state(batch_size, tf.float64))
@ -681,11 +688,13 @@ class LSTMTest(tf.test.TestCase):
use_peepholes=True,
initializer=initializer,
num_unit_shards=num_unit_shards,
num_proj_shards=num_proj_shards)
num_proj_shards=num_proj_shards,
state_is_tuple=False)
cell_shard = tf.nn.rnn_cell.LSTMCell(
num_units, use_peepholes=True,
initializer=initializer, num_proj=num_proj)
initializer=initializer, num_proj=num_proj,
state_is_tuple=False)
with tf.variable_scope("noshard_scope"):
outputs_noshard, state_noshard = tf.nn.rnn(
@ -734,7 +743,8 @@ class LSTMTest(tf.test.TestCase):
num_proj=num_proj,
num_unit_shards=num_unit_shards,
num_proj_shards=num_proj_shards,
initializer=initializer)
initializer=initializer,
state_is_tuple=False)
dropout_cell = tf.nn.rnn_cell.DropoutWrapper(cell, 0.5, seed=0)
outputs, state = tf.nn.rnn(
@ -766,10 +776,12 @@ class LSTMTest(tf.test.TestCase):
tf.placeholder(tf.float32, shape=(None, input_size))]
cell = tf.nn.rnn_cell.LSTMCell(
num_units, use_peepholes=True,
num_proj=num_proj, initializer=initializer)
num_proj=num_proj, initializer=initializer,
state_is_tuple=False)
cell_d = tf.nn.rnn_cell.LSTMCell(
num_units, use_peepholes=True,
num_proj=num_proj, initializer=initializer_d)
num_proj=num_proj, initializer=initializer_d,
state_is_tuple=False)
with tf.variable_scope("share_scope"):
outputs0, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
@ -805,7 +817,8 @@ class LSTMTest(tf.test.TestCase):
tf.placeholder(tf.float32, shape=(None, input_size))]
cell = tf.nn.rnn_cell.LSTMCell(
num_units, use_peepholes=True,
num_proj=num_proj, initializer=initializer)
num_proj=num_proj, initializer=initializer,
state_is_tuple=False)
with tf.name_scope("scope0"):
with tf.variable_scope("share_scope"):
@ -947,7 +960,7 @@ class LSTMTest(tf.test.TestCase):
cell = tf.nn.rnn_cell.LSTMCell(
num_units, use_peepholes=True,
initializer=initializer, num_proj=num_proj)
initializer=initializer, num_proj=num_proj, state_is_tuple=False)
with tf.variable_scope("dynamic_scope"):
outputs_static, state_static = tf.nn.rnn(
@ -1002,7 +1015,7 @@ class LSTMTest(tf.test.TestCase):
cell = tf.nn.rnn_cell.LSTMCell(
num_units, use_peepholes=True,
initializer=initializer, num_proj=num_proj)
initializer=initializer, num_proj=num_proj, state_is_tuple=False)
with tf.variable_scope("dynamic_scope"):
outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn(
@ -1141,10 +1154,12 @@ class BidirectionalRNNTest(tf.test.TestCase):
sequence_length = tf.placeholder(tf.int64) if use_sequence_length else None
cell_fw = tf.nn.rnn_cell.LSTMCell(num_units,
input_size,
initializer=initializer)
initializer=initializer,
state_is_tuple=False)
cell_bw = tf.nn.rnn_cell.LSTMCell(num_units,
input_size,
initializer=initializer)
initializer=initializer,
state_is_tuple=False)
inputs = max_length * [
tf.placeholder(
tf.float32,
@ -1889,7 +1904,8 @@ class StateSaverRNNTest(tf.test.TestCase):
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
state_saver = TestStateSaver(batch_size, 2 * num_units)
cell = tf.nn.rnn_cell.LSTMCell(
num_units, use_peepholes=False, initializer=initializer)
num_units, use_peepholes=False, initializer=initializer,
state_is_tuple=False)
inputs = max_length * [
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
return tf.nn.state_saving_rnn(
@ -1906,7 +1922,8 @@ def _static_vs_dynamic_rnn_benchmark_static(inputs_list_t, sequence_length):
(_, input_size) = inputs_list_t[0].get_shape().as_list()
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
cell = tf.nn.rnn_cell.LSTMCell(
num_units=input_size, use_peepholes=True, initializer=initializer)
num_units=input_size, use_peepholes=True, initializer=initializer,
state_is_tuple=False)
outputs, final_state = tf.nn.rnn(
cell, inputs_list_t, sequence_length=sequence_length, dtype=tf.float32)
@ -1920,7 +1937,8 @@ def _static_vs_dynamic_rnn_benchmark_dynamic(inputs_t, sequence_length):
(unused_0, unused_1, input_size) = inputs_t.get_shape().as_list()
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
cell = tf.nn.rnn_cell.LSTMCell(
num_units=input_size, use_peepholes=True, initializer=initializer)
num_units=input_size, use_peepholes=True, initializer=initializer,
state_is_tuple=False)
outputs, final_state = tf.nn.dynamic_rnn(
cell, inputs_t, sequence_length=sequence_length, dtype=tf.float32)
@ -2023,7 +2041,8 @@ def _half_seq_len_vs_unroll_half_rnn_benchmark(inputs_list_t, sequence_length):
(_, input_size) = inputs_list_t[0].get_shape().as_list()
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
cell = tf.nn.rnn_cell.LSTMCell(
num_units=input_size, use_peepholes=True, initializer=initializer)
num_units=input_size, use_peepholes=True, initializer=initializer,
state_is_tuple=False)
outputs, final_state = tf.nn.rnn(
cell, inputs_list_t, sequence_length=sequence_length, dtype=tf.float32)
@ -2132,7 +2151,8 @@ def _dynamic_rnn_swap_memory_benchmark(inputs_t, sequence_length,
(unused_0, unused_1, input_size) = inputs_t.get_shape().as_list()
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
cell = tf.nn.rnn_cell.LSTMCell(
num_units=input_size, use_peepholes=True, initializer=initializer)
num_units=input_size, use_peepholes=True, initializer=initializer,
state_is_tuple=False)
outputs, final_state = tf.nn.dynamic_rnn(
cell, inputs_t, sequence_length=sequence_length,
swap_memory=swap_memory, dtype=tf.float32)
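These rnn_test.py call sites pass state_is_tuple=False only because their assertions (and the TestStateSaver sized at 2 * num_units) expect a single concatenated state. New code can simply keep the tuple default; a hedged sketch, assuming the 0.10-era tf.nn.rnn handles tuple states as the tuple-state tests in this commit do:

import numpy as np
import tensorflow as tf

batch_size, input_size, num_units, max_length = 2, 5, 3, 4

with tf.variable_scope('tuple_rnn', initializer=tf.constant_initializer(0.5)):
    cell = tf.nn.rnn_cell.LSTMCell(num_units)  # tuple state by default now
    inputs = max_length * [
        tf.placeholder(tf.float32, shape=(batch_size, input_size))]
    outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
    # `state` is now a (c, m) pair rather than a [batch, 2 * num_units] Tensor.

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    input_value = np.ones((batch_size, input_size))
    values = sess.run(list(outputs) + list(state),
                      feed_dict={inputs[0]: input_value})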


@ -268,7 +268,7 @@ class BasicLSTMCell(RNNCell):
"""
def __init__(self, num_units, forget_bias=1.0, input_size=None,
state_is_tuple=False, activation=tanh):
state_is_tuple=True, activation=tanh):
"""Initialize the basic LSTM cell.
Args:
@ -276,8 +276,8 @@ class BasicLSTMCell(RNNCell):
forget_bias: float, The bias added to forget gates (see above).
input_size: Deprecated and unused.
state_is_tuple: If True, accepted and returned states are 2-tuples of
the `c_state` and `m_state`. By default (False), they are concatenated
along the column axis. This default behavior will soon be deprecated.
the `c_state` and `m_state`. If False, they are concatenated
along the column axis. The latter behavior will soon be deprecated.
activation: Activation function of the inner states.
"""
if not state_is_tuple:
@ -385,7 +385,7 @@ class LSTMCell(RNNCell):
use_peepholes=False, cell_clip=None,
initializer=None, num_proj=None, proj_clip=None,
num_unit_shards=1, num_proj_shards=1,
forget_bias=1.0, state_is_tuple=False,
forget_bias=1.0, state_is_tuple=True,
activation=tanh):
"""Initialize the parameters for an LSTM cell.
@ -410,8 +410,8 @@ class LSTMCell(RNNCell):
in order to reduce the scale of forgetting at the beginning of
the training.
state_is_tuple: If True, accepted and returned states are 2-tuples of
the `c_state` and `m_state`. By default (False), they are concatenated
along the column axis. This default behavior will soon be deprecated.
the `c_state` and `m_state`. If False, they are concatenated
along the column axis. This latter behavior will soon be deprecated.
activation: Activation function of the inner states.
"""
if not state_is_tuple:
@ -757,14 +757,15 @@ class EmbeddingWrapper(RNNCell):
class MultiRNNCell(RNNCell):
"""RNN cell composed sequentially of multiple simple cells."""
def __init__(self, cells, state_is_tuple=False):
def __init__(self, cells, state_is_tuple=True):
"""Create a RNN cell composed sequentially of a number of RNNCells.
Args:
cells: list of RNNCells that will be composed in this order.
state_is_tuple: If True, accepted and returned states are n-tuples, where
`n = len(cells)`. By default (False), the states are all
concatenated along the column axis.
`n = len(cells)`. If False, the states are all
concatenated along the column axis. This latter behavior will soon be
deprecated.
Raises:
ValueError: if cells is empty (not allowed), or at least one of the cells