** BREAKING CHANGE **
Core RNNCell implementations now use state_is_tuple=True by default. This is part of the deprecation process for non-tuple LSTM and MultiRNNCell states.
Change: 130059769
This commit is contained in:
parent fc218ef4d5
commit bc15695781
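The practical impact of the default flip is easiest to see in code. The sketch below is illustrative only and not part of the commit (tensor names, sizes, and scope names are hypothetical); it shows the new tuple-state calling convention and the explicit opt-out that the diffs below apply throughout the test suite.

```python
import tensorflow as tf

x = tf.zeros([1, 2])

# New default: state_is_tuple=True, so the LSTM state is a (c, m) pair.
cell = tf.nn.rnn_cell.BasicLSTMCell(2)
c, m = tf.zeros([1, 2]), tf.zeros([1, 2])
with tf.variable_scope('tuple_state'):
  output, new_state = cell(x, (c, m))  # new_state is again a (c, m) tuple

# Opt out explicitly to keep the old concatenated state until it is removed:
legacy_cell = tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
with tf.variable_scope('legacy_state'):
  legacy_output, legacy_state = legacy_cell(x, tf.zeros([1, 4]))  # [1, 4]
```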
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Tests for GridRNN cells."""
 
 from __future__ import absolute_import
@@ -27,7 +26,8 @@ class GridRNNCellTest(tf.test.TestCase):
 
   def testGrid2BasicLSTMCell(self):
     with self.test_session() as sess:
-      with tf.variable_scope('root', initializer=tf.constant_initializer(0.2)) as root_scope:
+      with tf.variable_scope(
+          'root', initializer=tf.constant_initializer(0.2)) as root_scope:
         x = tf.zeros([1, 3])
         m = tf.zeros([1, 8])
         cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(2)
@@ -38,15 +38,18 @@ class GridRNNCellTest(tf.test.TestCase):
         self.assertEqual(s.get_shape(), (1, 8))
 
         sess.run([tf.initialize_all_variables()])
-        res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
-                                m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
+        res = sess.run(
+            [g, s], {x: np.array([[1., 1., 1.]]),
+                     m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
         self.assertEqual(res[0].shape, (1, 2))
         self.assertEqual(res[1].shape, (1, 8))
-        self.assertAllClose(res[0], [[ 0.36617181, 0.36617181]])
-        self.assertAllClose(res[1], [[ 0.71053141, 0.71053141, 0.36617181, 0.36617181,
-                                       0.72320831, 0.80555487, 0.39102408, 0.42150158]])
+        self.assertAllClose(res[0], [[0.36617181, 0.36617181]])
+        self.assertAllClose(res[1], [[0.71053141, 0.71053141, 0.36617181,
+                                      0.36617181, 0.72320831, 0.80555487,
+                                      0.39102408, 0.42150158]])
 
-        # emulate a loop through the input sequence, where we call cell() multiple times
+        # emulate a loop through the input sequence,
+        # where we call cell() multiple times
         root_scope.reuse_variables()
         g2, s2 = cell(x, m)
         self.assertEqual(g2.get_shape(), (1, 2))
@@ -56,8 +59,9 @@ class GridRNNCellTest(tf.test.TestCase):
         self.assertEqual(res[0].shape, (1, 2))
         self.assertEqual(res[1].shape, (1, 8))
         self.assertAllClose(res[0], [[0.58847463, 0.58847463]])
-        self.assertAllClose(res[1], [[1.40469193, 1.40469193, 0.58847463, 0.58847463,
-                                      0.97726452, 1.04626071, 0.4927212, 0.51137757]])
+        self.assertAllClose(res[1], [[1.40469193, 1.40469193, 0.58847463,
+                                      0.58847463, 0.97726452, 1.04626071,
+                                      0.4927212, 0.51137757]])
 
   def testGrid2BasicLSTMCellTied(self):
     with self.test_session() as sess:
@@ -72,27 +76,31 @@ class GridRNNCellTest(tf.test.TestCase):
         self.assertEqual(s.get_shape(), (1, 8))
 
         sess.run([tf.initialize_all_variables()])
-        res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
-                                m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
+        res = sess.run(
+            [g, s], {x: np.array([[1., 1., 1.]]),
+                     m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
         self.assertEqual(res[0].shape, (1, 2))
         self.assertEqual(res[1].shape, (1, 8))
-        self.assertAllClose(res[0], [[ 0.36617181, 0.36617181]])
-        self.assertAllClose(res[1], [[ 0.71053141, 0.71053141, 0.36617181, 0.36617181,
-                                       0.72320831, 0.80555487, 0.39102408, 0.42150158]])
+        self.assertAllClose(res[0], [[0.36617181, 0.36617181]])
+        self.assertAllClose(res[1], [[0.71053141, 0.71053141, 0.36617181,
+                                      0.36617181, 0.72320831, 0.80555487,
+                                      0.39102408, 0.42150158]])
 
         res = sess.run([g, s], {x: np.array([[1., 1., 1.]]), m: res[1]})
         self.assertEqual(res[0].shape, (1, 2))
         self.assertEqual(res[1].shape, (1, 8))
         self.assertAllClose(res[0], [[0.36703536, 0.36703536]])
-        self.assertAllClose(res[1], [[0.71200621, 0.71200621, 0.36703536, 0.36703536,
-                                      0.80941606, 0.87550586, 0.40108523, 0.42199609]])
+        self.assertAllClose(res[1], [[0.71200621, 0.71200621, 0.36703536,
+                                      0.36703536, 0.80941606, 0.87550586,
+                                      0.40108523, 0.42199609]])
 
   def testGrid2BasicLSTMCellWithRelu(self):
     with self.test_session() as sess:
       with tf.variable_scope('root', initializer=tf.constant_initializer(0.2)):
         x = tf.zeros([1, 3])
         m = tf.zeros([1, 4])
-        cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(2, tied=False, non_recurrent_fn=tf.nn.relu)
+        cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(
+            2, tied=False, non_recurrent_fn=tf.nn.relu)
         self.assertEqual(cell.state_size, 4)
 
         g, s = cell(x, m)
@@ -104,11 +112,11 @@ class GridRNNCellTest(tf.test.TestCase):
                                 m: np.array([[0.1, 0.2, 0.3, 0.4]])})
         self.assertEqual(res[0].shape, (1, 2))
         self.assertEqual(res[1].shape, (1, 4))
-        self.assertAllClose(res[0], [[ 0.31667367, 0.31667367]])
-        self.assertAllClose(res[1], [[ 0.29530135, 0.37520045, 0.17044567, 0.21292259]])
+        self.assertAllClose(res[0], [[0.31667367, 0.31667367]])
+        self.assertAllClose(res[1],
+                            [[0.29530135, 0.37520045, 0.17044567, 0.21292259]])
 
-  """
-  LSTMCell
+  """LSTMCell
   """
 
   def testGrid2LSTMCell(self):
@@ -124,20 +132,23 @@ class GridRNNCellTest(tf.test.TestCase):
         self.assertEqual(s.get_shape(), (1, 8))
 
         sess.run([tf.initialize_all_variables()])
-        res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
-                                m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
+        res = sess.run(
+            [g, s], {x: np.array([[1., 1., 1.]]),
+                     m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
         self.assertEqual(res[0].shape, (1, 2))
         self.assertEqual(res[1].shape, (1, 8))
-        self.assertAllClose(res[0], [[ 0.95686918, 0.95686918]])
-        self.assertAllClose(res[1], [[ 2.41515064, 2.41515064, 0.95686918, 0.95686918,
-                                       1.38917875, 1.49043763, 0.83884692, 0.86036491]])
+        self.assertAllClose(res[0], [[0.95686918, 0.95686918]])
+        self.assertAllClose(res[1], [[2.41515064, 2.41515064, 0.95686918,
+                                      0.95686918, 1.38917875, 1.49043763,
+                                      0.83884692, 0.86036491]])
 
   def testGrid2LSTMCellTied(self):
     with self.test_session() as sess:
       with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
         x = tf.zeros([1, 3])
         m = tf.zeros([1, 8])
-        cell = tf.contrib.grid_rnn.Grid2LSTMCell(2, tied=True, use_peepholes=True)
+        cell = tf.contrib.grid_rnn.Grid2LSTMCell(
+            2, tied=True, use_peepholes=True)
         self.assertEqual(cell.state_size, 8)
 
         g, s = cell(x, m)
@@ -145,20 +156,23 @@ class GridRNNCellTest(tf.test.TestCase):
         self.assertEqual(s.get_shape(), (1, 8))
 
         sess.run([tf.initialize_all_variables()])
-        res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
-                                m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
+        res = sess.run(
+            [g, s], {x: np.array([[1., 1., 1.]]),
+                     m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
         self.assertEqual(res[0].shape, (1, 2))
         self.assertEqual(res[1].shape, (1, 8))
-        self.assertAllClose(res[0], [[ 0.95686918, 0.95686918]])
-        self.assertAllClose(res[1], [[ 2.41515064, 2.41515064, 0.95686918, 0.95686918,
-                                       1.38917875, 1.49043763, 0.83884692, 0.86036491]])
+        self.assertAllClose(res[0], [[0.95686918, 0.95686918]])
+        self.assertAllClose(res[1], [[2.41515064, 2.41515064, 0.95686918,
+                                      0.95686918, 1.38917875, 1.49043763,
+                                      0.83884692, 0.86036491]])
 
   def testGrid2LSTMCellWithRelu(self):
     with self.test_session() as sess:
       with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
         x = tf.zeros([1, 3])
         m = tf.zeros([1, 4])
-        cell = tf.contrib.grid_rnn.Grid2LSTMCell(2, use_peepholes=True, non_recurrent_fn=tf.nn.relu)
+        cell = tf.contrib.grid_rnn.Grid2LSTMCell(
+            2, use_peepholes=True, non_recurrent_fn=tf.nn.relu)
         self.assertEqual(cell.state_size, 4)
 
         g, s = cell(x, m)
@@ -170,11 +184,11 @@ class GridRNNCellTest(tf.test.TestCase):
                                 m: np.array([[0.1, 0.2, 0.3, 0.4]])})
         self.assertEqual(res[0].shape, (1, 2))
         self.assertEqual(res[1].shape, (1, 4))
-        self.assertAllClose(res[0], [[ 2.1831727, 2.1831727]])
-        self.assertAllClose(res[1], [[ 0.92270052, 1.02325559, 0.66159075, 0.70475441]])
+        self.assertAllClose(res[0], [[2.1831727, 2.1831727]])
+        self.assertAllClose(res[1],
+                            [[0.92270052, 1.02325559, 0.66159075, 0.70475441]])
 
-  """
-  RNNCell
+  """RNNCell
   """
 
   def testGrid2BasicRNNCell(self):
@@ -190,14 +204,16 @@ class GridRNNCellTest(tf.test.TestCase):
         self.assertEqual(s.get_shape(), (2, 4))
 
         sess.run([tf.initialize_all_variables()])
-        res = sess.run([g, s], {x: np.array([[1., 1.], [2., 2.]]),
-                                m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])})
+        res = sess.run(
+            [g, s], {x: np.array([[1., 1.], [2., 2.]]),
+                     m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])})
         self.assertEqual(res[0].shape, (2, 2))
         self.assertEqual(res[1].shape, (2, 4))
         self.assertAllClose(res[0], [[0.94685763, 0.94685763],
                                      [0.99480951, 0.99480951]])
-        self.assertAllClose(res[1], [[0.94685763, 0.94685763, 0.80049908, 0.80049908],
-                                     [0.99480951, 0.99480951, 0.97574311, 0.97574311]])
+        self.assertAllClose(res[1],
+                            [[0.94685763, 0.94685763, 0.80049908, 0.80049908],
+                             [0.99480951, 0.99480951, 0.97574311, 0.97574311]])
 
   def testGrid2BasicRNNCellTied(self):
     with self.test_session() as sess:
@@ -212,21 +228,24 @@ class GridRNNCellTest(tf.test.TestCase):
         self.assertEqual(s.get_shape(), (2, 4))
 
         sess.run([tf.initialize_all_variables()])
-        res = sess.run([g, s], {x: np.array([[1., 1.], [2., 2.]]),
-                                m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])})
+        res = sess.run(
+            [g, s], {x: np.array([[1., 1.], [2., 2.]]),
+                     m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])})
         self.assertEqual(res[0].shape, (2, 2))
         self.assertEqual(res[1].shape, (2, 4))
         self.assertAllClose(res[0], [[0.94685763, 0.94685763],
                                      [0.99480951, 0.99480951]])
-        self.assertAllClose(res[1], [[0.94685763, 0.94685763, 0.80049908, 0.80049908],
-                                     [0.99480951, 0.99480951, 0.97574311, 0.97574311]])
+        self.assertAllClose(res[1],
+                            [[0.94685763, 0.94685763, 0.80049908, 0.80049908],
+                             [0.99480951, 0.99480951, 0.97574311, 0.97574311]])
 
   def testGrid2BasicRNNCellWithRelu(self):
     with self.test_session() as sess:
       with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
         x = tf.zeros([1, 2])
         m = tf.zeros([1, 2])
-        cell = tf.contrib.grid_rnn.Grid2BasicRNNCell(2, non_recurrent_fn=tf.nn.relu)
+        cell = tf.contrib.grid_rnn.Grid2BasicRNNCell(
+            2, non_recurrent_fn=tf.nn.relu)
         self.assertEqual(cell.state_size, 2)
 
         g, s = cell(x, m)
@@ -234,19 +253,20 @@ class GridRNNCellTest(tf.test.TestCase):
         self.assertEqual(s.get_shape(), (1, 2))
 
         sess.run([tf.initialize_all_variables()])
-        res = sess.run([g, s], {x: np.array([[1., 1.]]), m: np.array([[0.1, 0.1]])})
+        res = sess.run([g, s], {x: np.array([[1., 1.]]),
+                                m: np.array([[0.1, 0.1]])})
         self.assertEqual(res[0].shape, (1, 2))
         self.assertEqual(res[1].shape, (1, 2))
         self.assertAllClose(res[0], [[1.80049896, 1.80049896]])
         self.assertAllClose(res[1], [[0.80049896, 0.80049896]])
 
-  """
-  1-LSTM
+  """1-LSTM
   """
 
   def testGrid1LSTMCell(self):
     with self.test_session() as sess:
-      with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)) as root_scope:
+      with tf.variable_scope(
+          'root', initializer=tf.constant_initializer(0.5)) as root_scope:
         x = tf.zeros([1, 3])
         m = tf.zeros([1, 4])
         cell = tf.contrib.grid_rnn.Grid1LSTMCell(2, use_peepholes=True)
@@ -262,7 +282,8 @@ class GridRNNCellTest(tf.test.TestCase):
         self.assertEqual(res[0].shape, (1, 2))
         self.assertEqual(res[1].shape, (1, 4))
         self.assertAllClose(res[0], [[0.91287315, 0.91287315]])
-        self.assertAllClose(res[1], [[2.26285243, 2.26285243, 0.91287315, 0.91287315]])
+        self.assertAllClose(res[1],
+                            [[2.26285243, 2.26285243, 0.91287315, 0.91287315]])
 
         root_scope.reuse_variables()
 
@@ -276,7 +297,8 @@ class GridRNNCellTest(tf.test.TestCase):
         self.assertEqual(res[0].shape, (1, 2))
         self.assertEqual(res[1].shape, (1, 4))
         self.assertAllClose(res[0], [[0.9032144, 0.9032144]])
-        self.assertAllClose(res[1], [[2.79966092, 2.79966092, 0.9032144, 0.9032144]])
+        self.assertAllClose(res[1],
+                            [[2.79966092, 2.79966092, 0.9032144, 0.9032144]])
 
         g3, s3 = cell(x2, m)
         self.assertEqual(g3.get_shape(), (1, 2))
@@ -287,11 +309,12 @@ class GridRNNCellTest(tf.test.TestCase):
         self.assertEqual(res[0].shape, (1, 2))
         self.assertEqual(res[1].shape, (1, 4))
         self.assertAllClose(res[0], [[0.92727238, 0.92727238]])
-        self.assertAllClose(res[1], [[3.3529923, 3.3529923, 0.92727238, 0.92727238]])
+        self.assertAllClose(res[1],
+                            [[3.3529923, 3.3529923, 0.92727238, 0.92727238]])
+
+  """3-LSTM
   """
-  3-LSTM
-  """
 
   def testGrid3LSTMCell(self):
     with self.test_session() as sess:
       with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
@@ -306,18 +329,20 @@ class GridRNNCellTest(tf.test.TestCase):
 
         sess.run([tf.initialize_all_variables()])
         res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
-                                m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, -0.1, -0.2, -0.3, -0.4]])})
+                                m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7,
+                                              0.8, -0.1, -0.2, -0.3, -0.4]])})
         self.assertEqual(res[0].shape, (1, 2))
         self.assertEqual(res[1].shape, (1, 12))
 
         self.assertAllClose(res[0], [[0.96892911, 0.96892911]])
-        self.assertAllClose(res[1], [[2.45227885, 2.45227885, 0.96892911, 0.96892911,
-                                      1.33592629, 1.4373529, 0.80867189, 0.83247656,
-                                      0.7317788, 0.63205892, 0.56548983, 0.50446129]])
+        self.assertAllClose(res[1], [[2.45227885, 2.45227885, 0.96892911,
+                                      0.96892911, 1.33592629, 1.4373529,
+                                      0.80867189, 0.83247656, 0.7317788,
+                                      0.63205892, 0.56548983, 0.50446129]])
+
+  """Edge cases
   """
-  Edge cases
-  """
 
   def testGridRNNEdgeCasesLikeRelu(self):
     with self.test_session() as sess:
       with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
@@ -325,8 +350,13 @@ class GridRNNCellTest(tf.test.TestCase):
         m = tf.zeros([0, 0])
 
         # this is equivalent to relu
-        cell = tf.contrib.grid_rnn.GridRNNCell(num_units=2, num_dims=1, input_dims=0, output_dims=0,
-                                               non_recurrent_dims=0, non_recurrent_fn=tf.nn.relu)
+        cell = tf.contrib.grid_rnn.GridRNNCell(
+            num_units=2,
+            num_dims=1,
+            input_dims=0,
+            output_dims=0,
+            non_recurrent_dims=0,
+            non_recurrent_fn=tf.nn.relu)
         g, s = cell(x, m)
         self.assertEqual(g.get_shape(), (3, 2))
         self.assertEqual(s.get_shape(), (0, 0))
@@ -340,12 +370,17 @@ class GridRNNCellTest(tf.test.TestCase):
   def testGridRNNEdgeCasesNoOutput(self):
     with self.test_session() as sess:
       with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
-        x = tf.zeros([1, 2])
+        x = tf.zeros([1, 2])
         m = tf.zeros([1, 4])
 
         # This cell produces no output
-        cell = tf.contrib.grid_rnn.GridRNNCell(num_units=2, num_dims=2, input_dims=0, output_dims=None,
-                                               non_recurrent_dims=0, non_recurrent_fn=tf.nn.relu)
+        cell = tf.contrib.grid_rnn.GridRNNCell(
+            num_units=2,
+            num_dims=2,
+            input_dims=0,
+            output_dims=None,
+            non_recurrent_dims=0,
+            non_recurrent_fn=tf.nn.relu)
         g, s = cell(x, m)
         self.assertEqual(g.get_shape(), (0, 0))
         self.assertEqual(s.get_shape(), (1, 4))
@@ -356,8 +391,7 @@ class GridRNNCellTest(tf.test.TestCase):
         self.assertEqual(res[0].shape, (0, 0))
         self.assertEqual(res[1].shape, (1, 4))
 
-  """
-  Test with tf.nn.rnn
+  """Test with tf.nn.rnn
   """
 
   def testGrid2LSTMCellWithRNN(self):
@@ -370,7 +404,9 @@ class GridRNNCellTest(tf.test.TestCase):
       cell = tf.contrib.grid_rnn.Grid2LSTMCell(num_units=num_units)
 
       inputs = max_length * [
-          tf.placeholder(tf.float32, shape=(batch_size, input_size))]
+          tf.placeholder(
+              tf.float32, shape=(batch_size, input_size))
+      ]
 
       outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
 
@@ -386,8 +422,7 @@ class GridRNNCellTest(tf.test.TestCase):
       sess.run(tf.initialize_all_variables())
 
      input_value = np.ones((batch_size, input_size))
-      values = sess.run(outputs + [state],
-                        feed_dict={inputs[0]: input_value})
+      values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
       for v in values:
         self.assertTrue(np.all(np.isfinite(v)))
 
@@ -398,10 +433,13 @@ class GridRNNCellTest(tf.test.TestCase):
     num_units = 2
 
     with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
-      cell = tf.contrib.grid_rnn.Grid2LSTMCell(num_units=num_units, non_recurrent_fn=tf.nn.relu)
+      cell = tf.contrib.grid_rnn.Grid2LSTMCell(
+          num_units=num_units, non_recurrent_fn=tf.nn.relu)
 
      inputs = max_length * [
-          tf.placeholder(tf.float32, shape=(batch_size, input_size))]
+          tf.placeholder(
+              tf.float32, shape=(batch_size, input_size))
+      ]
 
       outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
 
@@ -417,8 +455,7 @@ class GridRNNCellTest(tf.test.TestCase):
       sess.run(tf.initialize_all_variables())
 
       input_value = np.ones((batch_size, input_size))
-      values = sess.run(outputs + [state],
-                        feed_dict={inputs[0]: input_value})
+      values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
       for v in values:
         self.assertTrue(np.all(np.isfinite(v)))
 
@@ -429,10 +466,13 @@ class GridRNNCellTest(tf.test.TestCase):
     num_units = 2
 
    with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
-      cell = tf.contrib.grid_rnn.Grid3LSTMCell(num_units=num_units, non_recurrent_fn=tf.nn.relu)
+      cell = tf.contrib.grid_rnn.Grid3LSTMCell(
+          num_units=num_units, non_recurrent_fn=tf.nn.relu)
 
      inputs = max_length * [
-          tf.placeholder(tf.float32, shape=(batch_size, input_size))]
+          tf.placeholder(
+              tf.float32, shape=(batch_size, input_size))
+      ]
 
       outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
 
@@ -448,12 +488,10 @@ class GridRNNCellTest(tf.test.TestCase):
      sess.run(tf.initialize_all_variables())
 
      input_value = np.ones((batch_size, input_size))
-      values = sess.run(outputs + [state],
-                        feed_dict={inputs[0]: input_value})
+      values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
      for v in values:
        self.assertTrue(np.all(np.isfinite(v)))
 
-
   def testGrid1LSTMCellWithRNN(self):
    batch_size = 3
    input_size = 5
@@ -464,8 +502,8 @@ class GridRNNCellTest(tf.test.TestCase):
      cell = tf.contrib.grid_rnn.Grid1LSTMCell(num_units=num_units)
 
      # for 1-LSTM, we only feed the first step
-      inputs = [tf.placeholder(tf.float32, shape=(batch_size, input_size))] \
-               + (max_length - 1) * [tf.zeros([batch_size, input_size])]
+      inputs = ([tf.placeholder(tf.float32, shape=(batch_size, input_size))]
+                + (max_length - 1) * [tf.zeros([batch_size, input_size])])
 
      outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
 
@@ -480,10 +518,10 @@ class GridRNNCellTest(tf.test.TestCase):
      sess.run(tf.initialize_all_variables())
 
      input_value = np.ones((batch_size, input_size))
-      values = sess.run(outputs + [state],
-                        feed_dict={inputs[0]: input_value})
+      values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
      for v in values:
        self.assertTrue(np.all(np.isfinite(v)))
 
-if __name__ == "__main__":
+
+if __name__ == '__main__':
   tf.test.main()
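One pattern worth calling out from the tests above: a GridRNN cell is a plain callable, so stepping it over a sequence means calling it repeatedly under one variable scope and switching that scope to reuse mode after the first call. A minimal sketch of that pattern, not part of the diff (shapes taken from testGrid2BasicLSTMCell):

```python
with tf.variable_scope(
    'root', initializer=tf.constant_initializer(0.2)) as root_scope:
  x = tf.zeros([1, 3])
  m = tf.zeros([1, 8])
  cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(2)

  g1, s1 = cell(x, m)           # first call creates the variables
  root_scope.reuse_variables()  # later calls must reuse, not re-create
  g2, s2 = cell(x, s1)          # feed the previous state back in
```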
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+# TODO(b/28879898) Fix all lint issues and clean the code.
 """Module for constructing GridRNN cells"""
-
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -35,46 +35,67 @@ class GridRNNCell(rnn_cell.RNNCell):
 
   http://arxiv.org/pdf/1507.01526v3.pdf
 
-  This is the generic implementation of GridRNN. Users can specify arbitrary number of dimensions,
+  This is the generic implementation of GridRNN. Users can specify arbitrary
+  number of dimensions,
   set some of them to be priority (section 3.2), non-recurrent (section 3.3)
   and input/output dimensions (section 3.4).
   Weight sharing can also be specified using the `tied` parameter.
   Type of recurrent units can be specified via `cell_fn`.
   """
 
-  def __init__(self, num_units, num_dims=1, input_dims=None, output_dims=None, priority_dims=None,
-               non_recurrent_dims=None, tied=False, cell_fn=None, non_recurrent_fn=None):
+  def __init__(self,
+               num_units,
+               num_dims=1,
+               input_dims=None,
+               output_dims=None,
+               priority_dims=None,
+               non_recurrent_dims=None,
+               tied=False,
+               cell_fn=None,
+               non_recurrent_fn=None):
     """Initialize the parameters of a Grid RNN cell
 
     Args:
       num_units: int, The number of units in all dimensions of this GridRNN cell
      num_dims: int, Number of dimensions of this grid.
      input_dims: int or list, List of dimensions which will receive input data.
-      output_dims: int or list, List of dimensions from which the output will be recorded.
-      priority_dims: int or list, List of dimensions to be considered as priority dimensions.
+      output_dims: int or list, List of dimensions from which the output will be
+        recorded.
+      priority_dims: int or list, List of dimensions to be considered as
+        priority dimensions.
        If None, no dimension is prioritized.
-      non_recurrent_dims: int or list, List of dimensions that are not recurrent.
-        The transfer function for non-recurrent dimensions is specified via `non_recurrent_fn`,
+      non_recurrent_dims: int or list, List of dimensions that are not
+        recurrent.
+        The transfer function for non-recurrent dimensions is specified
+        via `non_recurrent_fn`,
        which is default to be `tensorflow.nn.relu`.
-      tied: bool, Whether to share the weights among the dimensions of this GridRNN cell.
-        If there are non-recurrent dimensions in the grid, weights are shared between each
+      tied: bool, Whether to share the weights among the dimensions of this
+        GridRNN cell.
+        If there are non-recurrent dimensions in the grid, weights are
+        shared between each
        group of recurrent and non-recurrent dimensions.
-      cell_fn: function, a function which returns the recurrent cell object. Has to be in the following signature:
+      cell_fn: function, a function which returns the recurrent cell object. Has
+        to be in the following signature:
        def cell_func(num_units, input_size):
          # ...
 
-        and returns an object of type `RNNCell`. If None, LSTMCell with default parameters will be used.
-      non_recurrent_fn: a tensorflow Op that will be the transfer function of the non-recurrent dimensions
+        and returns an object of type `RNNCell`. If None, LSTMCell with
+        default parameters will be used.
+      non_recurrent_fn: a tensorflow Op that will be the transfer function of
+        the non-recurrent dimensions
    """
    if num_dims < 1:
      raise ValueError('dims must be >= 1: {}'.format(num_dims))
 
-    self._config = _parse_rnn_config(num_dims, input_dims, output_dims, priority_dims,
-                                     non_recurrent_dims, non_recurrent_fn or nn.relu, tied, num_units)
+    self._config = _parse_rnn_config(num_dims, input_dims, output_dims,
+                                     priority_dims, non_recurrent_dims,
+                                     non_recurrent_fn or nn.relu, tied,
+                                     num_units)
 
    cell_input_size = (self._config.num_dims - 1) * num_units
    if cell_fn is None:
-      self._cell = rnn_cell.LSTMCell(num_units=num_units, input_size=cell_input_size)
+      self._cell = rnn_cell.LSTMCell(
+          num_units=num_units, input_size=cell_input_size, state_is_tuple=False)
    else:
      self._cell = cell_fn(num_units, cell_input_size)
      if not isinstance(self._cell, rnn_cell.RNNCell):
@@ -100,7 +121,8 @@ class GridRNNCell(rnn_cell.RNNCell):
 
     Args:
       inputs: input Tensor, 2D, batch x input_size. Or None
-      state: state Tensor, 2D, batch x state_size. Note that state_size = cell_state_size * recurrent_dims
+      state: state Tensor, 2D, batch x state_size. Note that state_size =
+        cell_state_size * recurrent_dims
       scope: VariableScope for the created subgraph; defaults to "GridRNNCell".
 
     Returns:
@@ -112,24 +134,32 @@ class GridRNNCell(rnn_cell.RNNCell):
     """
     state_sz = state.get_shape().as_list()[1]
     if self.state_size != state_sz:
-      raise ValueError('Actual state size not same as specified: {} vs {}.'.format(state_sz, self.state_size))
+      raise ValueError(
+          'Actual state size not same as specified: {} vs {}.'.format(
+              state_sz, self.state_size))
 
     conf = self._config
     dtype = inputs.dtype if inputs is not None else state.dtype
 
-    # c_prev is `m`, and m_prev is `h` in the paper. Keep c and m here for consistency with the codebase
+    # c_prev is `m`, and m_prev is `h` in the paper.
+    # Keep c and m here for consistency with the codebase
     c_prev = [None] * self._config.num_dims
     m_prev = [None] * self._config.num_dims
     cell_output_size = self._cell.state_size - conf.num_units
 
     # for LSTM : state = memory cell + output, hence cell_output_size > 0
-    # for GRU/RNN: state = output (whose size is equal to _num_units), hence cell_output_size = 0
-    for recurrent_dim, start_idx in zip(self._config.recurrents, range(0, self.state_size, self._cell.state_size)):
+    # for GRU/RNN: state = output (whose size is equal to _num_units),
+    # hence cell_output_size = 0
+    for recurrent_dim, start_idx in zip(self._config.recurrents, range(
+        0, self.state_size, self._cell.state_size)):
       if cell_output_size > 0:
-        c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], [-1, conf.num_units])
-        m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx + conf.num_units], [-1, cell_output_size])
+        c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
+                                                [-1, conf.num_units])
+        m_prev[recurrent_dim] = array_ops.slice(
+            state, [0, start_idx + conf.num_units], [-1, cell_output_size])
       else:
-        m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], [-1, conf.num_units])
+        m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
+                                                [-1, conf.num_units])
 
     new_output = [None] * conf.num_dims
     new_state = [None] * conf.num_dims
@@ -137,150 +167,212 @@ class GridRNNCell(rnn_cell.RNNCell):
     with vs.variable_scope(scope or type(self).__name__):  # GridRNNCell
 
       # project input
-      if inputs is not None and sum(inputs.get_shape().as_list()) > 0 and len(conf.inputs) > 0:
+      if inputs is not None and sum(inputs.get_shape().as_list()) > 0 and len(
+          conf.inputs) > 0:
         input_splits = array_ops.split(1, len(conf.inputs), inputs)
         input_sz = input_splits[0].get_shape().as_list()[1]
 
         for i, j in enumerate(conf.inputs):
-          input_project_m = vs.get_variable('project_m_{}'.format(j), [input_sz, conf.num_units], dtype=dtype)
+          input_project_m = vs.get_variable(
+              'project_m_{}'.format(j), [input_sz, conf.num_units], dtype=dtype)
           m_prev[j] = math_ops.matmul(input_splits[i], input_project_m)
 
           if cell_output_size > 0:
-            input_project_c = vs.get_variable('project_c_{}'.format(j), [input_sz, conf.num_units], dtype=dtype)
+            input_project_c = vs.get_variable(
+                'project_c_{}'.format(j), [input_sz, conf.num_units],
+                dtype=dtype)
             c_prev[j] = math_ops.matmul(input_splits[i], input_project_c)
 
 
-      _propagate(conf.non_priority, conf, self._cell, c_prev, m_prev, new_output, new_state, True)
-      _propagate(conf.priority, conf, self._cell, c_prev, m_prev, new_output, new_state, False)
+      _propagate(conf.non_priority, conf, self._cell, c_prev, m_prev,
+                 new_output, new_state, True)
+      _propagate(conf.priority, conf, self._cell, c_prev, m_prev, new_output,
+                 new_state, False)
 
       output_tensors = [new_output[i] for i in self._config.outputs]
-      output = array_ops.zeros([0, 0], dtype) if len(output_tensors) == 0 else array_ops.concat(1,
-                                                                                                output_tensors)
+      output = array_ops.zeros(
+          [0, 0], dtype) if len(output_tensors) == 0 else array_ops.concat(
+              1, output_tensors)
 
       state_tensors = [new_state[i] for i in self._config.recurrents]
-      states = array_ops.zeros([0, 0], dtype) if len(state_tensors) == 0 else array_ops.concat(1, state_tensors)
+      states = array_ops.zeros(
+          [0, 0], dtype) if len(state_tensors) == 0 else array_ops.concat(
+              1, state_tensors)
 
     return output, states
 
 
+"""Specialized cells, for convenience
 """
-Specialized cells, for convenience
-"""
 
 
 class Grid1BasicRNNCell(GridRNNCell):
   """1D BasicRNN cell"""
 
   def __init__(self, num_units):
-    super(Grid1BasicRNNCell, self).__init__(num_units=num_units, num_dims=1,
-                                            input_dims=0, output_dims=0, priority_dims=0, tied=False,
-                                            cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i))
+    super(Grid1BasicRNNCell, self).__init__(
+        num_units=num_units, num_dims=1,
+        input_dims=0, output_dims=0, priority_dims=0, tied=False,
+        cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i))
 
 
 class Grid2BasicRNNCell(GridRNNCell):
   """2D BasicRNN cell
-  This creates a 2D cell which receives input and gives output in the first dimension.
-  The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified.
+
+  This creates a 2D cell which receives input and gives output in the first
+  dimension.
+
+  The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
+  specified.
   """
 
   def __init__(self, num_units, tied=False, non_recurrent_fn=None):
-    super(Grid2BasicRNNCell, self).__init__(num_units=num_units, num_dims=2,
-                                            input_dims=0, output_dims=0, priority_dims=0, tied=tied,
-                                            non_recurrent_dims=None if non_recurrent_fn is None else 0,
-                                            cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i),
-                                            non_recurrent_fn=non_recurrent_fn)
+    super(Grid2BasicRNNCell, self).__init__(
+        num_units=num_units, num_dims=2,
+        input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+        non_recurrent_dims=None if non_recurrent_fn is None else 0,
+        cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i),
+        non_recurrent_fn=non_recurrent_fn)
 
 
 class Grid1BasicLSTMCell(GridRNNCell):
   """1D BasicLSTM cell"""
 
   def __init__(self, num_units, forget_bias=1):
-    super(Grid1BasicLSTMCell, self).__init__(num_units=num_units, num_dims=1,
-                                             input_dims=0, output_dims=0, priority_dims=0, tied=False,
-                                             cell_fn=lambda n, i: rnn_cell.BasicLSTMCell(num_units=n,
-                                                                                         forget_bias=forget_bias, input_size=i))
+    super(Grid1BasicLSTMCell, self).__init__(
+        num_units=num_units, num_dims=1,
+        input_dims=0, output_dims=0, priority_dims=0, tied=False,
+        cell_fn=lambda n, i: rnn_cell.BasicLSTMCell(
+            num_units=n,
+            forget_bias=forget_bias, input_size=i,
+            state_is_tuple=False))
 
 
 class Grid2BasicLSTMCell(GridRNNCell):
   """2D BasicLSTM cell
-  This creates a 2D cell which receives input and gives output in the first dimension.
-  The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified.
+
+  This creates a 2D cell which receives input and gives output in the first
+  dimension.
+
+  The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
+  specified.
   """
-  def __init__(self, num_units, tied=False, non_recurrent_fn=None, forget_bias=1):
-    super(Grid2BasicLSTMCell, self).__init__(num_units=num_units, num_dims=2,
-                                             input_dims=0, output_dims=0, priority_dims=0, tied=tied,
-                                             non_recurrent_dims=None if non_recurrent_fn is None else 0,
-                                             cell_fn=lambda n, i: rnn_cell.BasicLSTMCell(
-                                                 num_units=n, forget_bias=forget_bias, input_size=i),
-                                             non_recurrent_fn=non_recurrent_fn)
+
+  def __init__(self,
+               num_units,
+               tied=False,
+               non_recurrent_fn=None,
+               forget_bias=1):
+    super(Grid2BasicLSTMCell, self).__init__(
+        num_units=num_units, num_dims=2,
+        input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+        non_recurrent_dims=None if non_recurrent_fn is None else 0,
+        cell_fn=lambda n, i: rnn_cell.BasicLSTMCell(
+            num_units=n, forget_bias=forget_bias, input_size=i,
+            state_is_tuple=False),
+        non_recurrent_fn=non_recurrent_fn)
 
 
 class Grid1LSTMCell(GridRNNCell):
   """1D LSTM cell
-  This is different from Grid1BasicLSTMCell because it gives options to specify the forget bias and enabling peepholes
+
+  This is different from Grid1BasicLSTMCell because it gives options to
+  specify the forget bias and enabling peepholes
   """
 
   def __init__(self, num_units, use_peepholes=False, forget_bias=1.0):
-    super(Grid1LSTMCell, self).__init__(num_units=num_units, num_dims=1,
-                                        input_dims=0, output_dims=0, priority_dims=0,
-                                        cell_fn=lambda n, i: rnn_cell.LSTMCell(
-                                            num_units=n, input_size=i, use_peepholes=use_peepholes,
-                                            forget_bias=forget_bias))
+    super(Grid1LSTMCell, self).__init__(
+        num_units=num_units, num_dims=1,
+        input_dims=0, output_dims=0, priority_dims=0,
+        cell_fn=lambda n, i: rnn_cell.LSTMCell(
+            num_units=n, input_size=i, use_peepholes=use_peepholes,
+            forget_bias=forget_bias, state_is_tuple=False))
 
 
 class Grid2LSTMCell(GridRNNCell):
   """2D LSTM cell
-  This creates a 2D cell which receives input and gives output in the first dimension.
-  The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified.
+
+  This creates a 2D cell which receives input and gives output in the first
+  dimension.
+  The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
+  specified.
   """
-  def __init__(self, num_units, tied=False, non_recurrent_fn=None,
-               use_peepholes=False, forget_bias=1.0):
-    super(Grid2LSTMCell, self).__init__(num_units=num_units, num_dims=2,
-                                        input_dims=0, output_dims=0, priority_dims=0, tied=tied,
-                                        non_recurrent_dims=None if non_recurrent_fn is None else 0,
-                                        cell_fn=lambda n, i: rnn_cell.LSTMCell(
-                                            num_units=n, input_size=i, forget_bias=forget_bias,
-                                            use_peepholes=use_peepholes),
-                                        non_recurrent_fn=non_recurrent_fn)
+
+  def __init__(self,
+               num_units,
+               tied=False,
+               non_recurrent_fn=None,
+               use_peepholes=False,
+               forget_bias=1.0):
+    super(Grid2LSTMCell, self).__init__(
+        num_units=num_units, num_dims=2,
+        input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+        non_recurrent_dims=None if non_recurrent_fn is None else 0,
+        cell_fn=lambda n, i: rnn_cell.LSTMCell(
+            num_units=n, input_size=i, forget_bias=forget_bias,
+            use_peepholes=use_peepholes, state_is_tuple=False),
+        non_recurrent_fn=non_recurrent_fn)
 
 
 class Grid3LSTMCell(GridRNNCell):
   """3D BasicLSTM cell
-  This creates a 2D cell which receives input and gives output in the first dimension.
-  The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified.
+
+  This creates a 2D cell which receives input and gives output in the first
+  dimension.
+  The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
+  specified.
   The second and third dimensions are LSTM.
   """
-  def __init__(self, num_units, tied=False, non_recurrent_fn=None,
-               use_peepholes=False, forget_bias=1.0):
-    super(Grid3LSTMCell, self).__init__(num_units=num_units, num_dims=3,
-                                        input_dims=0, output_dims=0, priority_dims=0, tied=tied,
-                                        non_recurrent_dims=None if non_recurrent_fn is None else 0,
-                                        cell_fn=lambda n, i: rnn_cell.LSTMCell(
-                                            num_units=n, input_size=i, forget_bias=forget_bias,
-                                            use_peepholes=use_peepholes),
-                                        non_recurrent_fn=non_recurrent_fn)
+
+  def __init__(self,
+               num_units,
+               tied=False,
+               non_recurrent_fn=None,
+               use_peepholes=False,
+               forget_bias=1.0):
+    super(Grid3LSTMCell, self).__init__(
+        num_units=num_units, num_dims=3,
+        input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+        non_recurrent_dims=None if non_recurrent_fn is None else 0,
+        cell_fn=lambda n, i: rnn_cell.LSTMCell(
+            num_units=n, input_size=i, forget_bias=forget_bias,
+            use_peepholes=use_peepholes, state_is_tuple=False),
+        non_recurrent_fn=non_recurrent_fn)
 
 
 class Grid2GRUCell(GridRNNCell):
   """2D LSTM cell
-  This creates a 2D cell which receives input and gives output in the first dimension.
-  The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified.
+
+  This creates a 2D cell which receives input and gives output in the first
+  dimension.
+  The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
+  specified.
   """
 
   def __init__(self, num_units, tied=False, non_recurrent_fn=None):
-    super(Grid2GRUCell, self).__init__(num_units=num_units, num_dims=2,
-                                       input_dims=0, output_dims=0, priority_dims=0, tied=tied,
-                                       non_recurrent_dims=None if non_recurrent_fn is None else 0,
-                                       cell_fn=lambda n, i: rnn_cell.GRUCell(num_units=n, input_size=i),
-                                       non_recurrent_fn=non_recurrent_fn)
+    super(Grid2GRUCell, self).__init__(
+        num_units=num_units, num_dims=2,
+        input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+        non_recurrent_dims=None if non_recurrent_fn is None else 0,
+        cell_fn=lambda n, i: rnn_cell.GRUCell(num_units=n, input_size=i),
+        non_recurrent_fn=non_recurrent_fn)
 
 
-"""
-Helpers
-
+"""Helpers
 """
 
-_GridRNNDimension = namedtuple('_GridRNNDimension', ['idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn'])
+_GridRNNDimension = namedtuple(
+    '_GridRNNDimension',
+    ['idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn'])
 
-_GridRNNConfig = namedtuple('_GridRNNConfig', ['num_dims', 'dims',
-                                               'inputs', 'outputs', 'recurrents',
-                                               'priority', 'non_priority', 'tied', 'num_units'])
+_GridRNNConfig = namedtuple('_GridRNNConfig',
+                            ['num_dims', 'dims', 'inputs', 'outputs',
+                             'recurrents', 'priority', 'non_priority', 'tied',
+                             'num_units'])
 
 
-def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims, ls_non_recurrent_dims,
-                      non_recurrent_fn, tied, num_units):
+def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims,
+                      ls_non_recurrent_dims, non_recurrent_fn, tied, num_units):
 
   def check_dim_list(ls):
     if ls is None:
       ls = []
@@ -288,7 +380,8 @@ def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims,
       ls = [ls]
     ls = sorted(set(ls))
     if any(_ < 0 or _ >= num_dims for _ in ls):
-      raise ValueError('Invalid dims: {}. Must be in [0, {})'.format(ls, num_dims))
+      raise ValueError('Invalid dims: {}. Must be in [0, {})'.format(ls,
+                                                                     num_dims))
     return ls
 
   input_dims = check_dim_list(ls_input_dims)
@@ -298,42 +391,58 @@ def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims,
 
   rnn_dims = []
   for i in range(num_dims):
-    rnn_dims.append(_GridRNNDimension(idx=i, is_input=(i in input_dims), is_output=(i in output_dims),
-                                      is_priority=(i in priority_dims),
-                                      non_recurrent_fn=non_recurrent_fn if i in non_recurrent_dims else None))
-  return _GridRNNConfig(num_dims=num_dims, dims=rnn_dims, inputs=input_dims, outputs=output_dims,
-                        recurrents=[x for x in range(num_dims) if x not in non_recurrent_dims],
-                        priority=priority_dims,
-                        non_priority=[x for x in range(num_dims) if x not in priority_dims],
-                        tied=tied, num_units=num_units)
+    rnn_dims.append(
+        _GridRNNDimension(
+            idx=i,
+            is_input=(i in input_dims),
+            is_output=(i in output_dims),
+            is_priority=(i in priority_dims),
+            non_recurrent_fn=non_recurrent_fn if i in non_recurrent_dims else
+            None))
+  return _GridRNNConfig(
+      num_dims=num_dims,
+      dims=rnn_dims,
+      inputs=input_dims,
+      outputs=output_dims,
+      recurrents=[x for x in range(num_dims) if x not in non_recurrent_dims],
+      priority=priority_dims,
+      non_priority=[x for x in range(num_dims) if x not in priority_dims],
+      tied=tied,
+      num_units=num_units)
 
 
-def _propagate(dim_indices, conf, cell, c_prev, m_prev, new_output, new_state, first_call):
-  """
-  Propagates through all the cells in dim_indices dimensions.
+def _propagate(dim_indices, conf, cell, c_prev, m_prev, new_output, new_state,
+               first_call):
+  """Propagates through all the cells in dim_indices dimensions.
   """
   if len(dim_indices) == 0:
     return
 
-  # Because of the way RNNCells are implemented, we take the last dimension (H_{N-1}) out
-  # and feed it as the state of the RNN cell (in `last_dim_output`)
+  # Because of the way RNNCells are implemented, we take the last dimension
+  # (H_{N-1}) out and feed it as the state of the RNN cell
+  # (in `last_dim_output`).
   # The input of the cell (H_0 to H_{N-2}) are concatenated into `cell_inputs`
   if conf.num_dims > 1:
     ls_cell_inputs = [None] * (conf.num_dims - 1)
     for d in conf.dims[:-1]:
-      ls_cell_inputs[d.idx] = new_output[d.idx] if new_output[d.idx] is not None else m_prev[d.idx]
+      ls_cell_inputs[d.idx] = new_output[d.idx] if new_output[
+          d.idx] is not None else m_prev[d.idx]
     cell_inputs = array_ops.concat(1, ls_cell_inputs)
   else:
-    cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0], m_prev[0].dtype)
+    cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0],
+                                  m_prev[0].dtype)
 
   last_dim_output = new_output[-1] if new_output[-1] is not None else m_prev[-1]
 
   for i in dim_indices:
     d = conf.dims[i]
     if d.non_recurrent_fn:
-      linear_args = array_ops.concat(1, [cell_inputs, last_dim_output]) if conf.num_dims > 1 else last_dim_output
-      with vs.variable_scope('non_recurrent' if conf.tied else 'non_recurrent/cell_{}'.format(i)):
-        if conf.tied and not(first_call and i == dim_indices[0]):
+      linear_args = array_ops.concat(
+          1, [cell_inputs, last_dim_output
+             ]) if conf.num_dims > 1 else last_dim_output
+      with vs.variable_scope('non_recurrent' if conf.tied else
+                             'non_recurrent/cell_{}'.format(i)):
+        if conf.tied and not (first_call and i == dim_indices[0]):
          vs.get_variable_scope().reuse_variables()
        new_output[d.idx] = layers.legacy_fully_connected(
            linear_args,
@@ -348,7 +457,8 @@ def _propagate(dim_indices, conf, cell, c_prev, m_prev, new_output, new_state, f
       # for GRU/RNN, the state is just the previous output
       cell_state = last_dim_output
 
-    with vs.variable_scope('recurrent' if conf.tied else 'recurrent/cell_{}'.format(i)):
+    with vs.variable_scope('recurrent' if conf.tied else
+                           'recurrent/cell_{}'.format(i)):
       if conf.tied and not (first_call and i == dim_indices[0]):
         vs.get_variable_scope().reuse_variables()
       new_output[d.idx], new_state[d.idx] = cell(cell_inputs, cell_state)
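As the `__call__` docstring above notes, GridRNNCell keeps its entire state in one concatenated Tensor of size cell_state_size * recurrent_dims, which is why the internal LSTMCell is now constructed with an explicit state_is_tuple=False. A small illustrative sketch of the resulting sizes, matching the assertions in the tests (not part of the diff):

```python
cell = tf.contrib.grid_rnn.Grid2LSTMCell(2)
# 2 recurrent dimensions, each holding an LSTM state of c (2 units)
# plus m (2 units): 2 * (2 + 2) = 8 columns in the state Tensor.
assert cell.state_size == 8
```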
@@ -89,7 +89,8 @@ class RNNCellTest(tf.test.TestCase):
       x = tf.zeros([1, 2])
       m = tf.zeros([1, 8])
       g, out_m = tf.nn.rnn_cell.MultiRNNCell(
-          [tf.nn.rnn_cell.BasicLSTMCell(2)] * 2)(x, m)
+          [tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=False)] * 2,
+          state_is_tuple=False)(x, m)
       sess.run([tf.initialize_all_variables()])
       res = sess.run([g, out_m], {x.name: np.array([[1., 1.]]),
                                   m.name: 0.1 * np.ones([1, 8])})
@@ -104,7 +105,7 @@ class RNNCellTest(tf.test.TestCase):
     with tf.variable_scope("other", initializer=tf.constant_initializer(0.5)):
       x = tf.zeros([1, 3])  # Test BasicLSTMCell with input_size != num_units.
       m = tf.zeros([1, 4])
-      g, out_m = tf.nn.rnn_cell.BasicLSTMCell(2)(x, m)
+      g, out_m = tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=False)(x, m)
       sess.run([tf.initialize_all_variables()])
       res = sess.run([g, out_m], {x.name: np.array([[1., 1., 1.]]),
                                   m.name: 0.1 * np.ones([1, 4])})
@@ -117,7 +118,7 @@ class RNNCellTest(tf.test.TestCase):
       m0 = (tf.zeros([1, 2]),) * 2
       m1 = (tf.zeros([1, 2]),) * 2
       cell = tf.nn.rnn_cell.MultiRNNCell(
-          [tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=True)] * 2,
+          [tf.nn.rnn_cell.BasicLSTMCell(2)] * 2,
           state_is_tuple=True)
       self.assertTrue(isinstance(cell.state_size, tuple))
       self.assertTrue(isinstance(cell.state_size[0],
@@ -153,7 +154,8 @@ class RNNCellTest(tf.test.TestCase):
       m0 = tf.zeros([1, 4])
       m1 = tf.zeros([1, 4])
       cell = tf.nn.rnn_cell.MultiRNNCell(
-          [tf.nn.rnn_cell.BasicLSTMCell(2)] * 2, state_is_tuple=True)
+          [tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=False)] * 2,
+          state_is_tuple=True)
       g, (out_m0, out_m1) = cell(x, (m0, m1))
       sess.run([tf.initialize_all_variables()])
       res = sess.run([g, out_m0, out_m1],
@@ -183,7 +185,8 @@ class RNNCellTest(tf.test.TestCase):
       x = tf.zeros([batch_size, input_size])
       m = tf.zeros([batch_size, state_size])
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units=num_units, num_proj=num_proj, forget_bias=1.0)
+          num_units=num_units, num_proj=num_proj, forget_bias=1.0,
+          state_is_tuple=False)
       output, state = cell(x, m)
       sess.run([tf.initialize_all_variables()])
       res = sess.run([output, state],
@@ -286,7 +289,7 @@ class RNNCellTest(tf.test.TestCase):
       x = tf.zeros([1, 2])
       m = tf.zeros([1, 4])
       _, ml = tf.nn.rnn_cell.MultiRNNCell(
-          [tf.nn.rnn_cell.GRUCell(2)] * 2)(x, m)
+          [tf.nn.rnn_cell.GRUCell(2)] * 2, state_is_tuple=False)(x, m)
       sess.run([tf.initialize_all_variables()])
       res = sess.run(ml, {x.name: np.array([[1., 1.]]),
                           m.name: np.array([[0.1, 0.1, 0.1, 0.1]])})
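Note in the tests above that the tuple flag of a MultiRNNCell and of its inner cells are independent: the outer cell can return an n-tuple whose elements are themselves concatenated LSTM states. A sketch of that mixed configuration, with shapes as in the test (illustrative only):

```python
cell = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=False)] * 2,
    state_is_tuple=True)
# The multi-cell state is a 2-tuple; each element is a [batch, 4]
# Tensor holding one inner cell's concatenated (c, m) state.
m0 = tf.zeros([1, 4])
m1 = tf.zeros([1, 4])
g, (out_m0, out_m1) = cell(tf.zeros([1, 2]), (m0, m1))
```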
@@ -363,7 +363,8 @@ class LSTMTest(tf.test.TestCase):
     max_length = 8
     with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
       initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
-      cell = tf.nn.rnn_cell.LSTMCell(num_units, initializer=initializer)
+      cell = tf.nn.rnn_cell.LSTMCell(num_units, initializer=initializer,
+                                     state_is_tuple=False)
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(batch_size, input_size))]
       outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
@@ -383,7 +384,8 @@ class LSTMTest(tf.test.TestCase):
     with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
       initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, use_peepholes=True, cell_clip=0.0, initializer=initializer)
+          num_units, use_peepholes=True, cell_clip=0.0, initializer=initializer,
+          state_is_tuple=False)
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(batch_size, input_size))]
       outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
@@ -408,7 +410,8 @@ class LSTMTest(tf.test.TestCase):
       initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
       state_saver = TestStateSaver(batch_size, 2 * num_units)
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, use_peepholes=False, initializer=initializer)
+          num_units, use_peepholes=False, initializer=initializer,
+          state_is_tuple=False)
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(batch_size, input_size))]
       with tf.variable_scope("share_scope"):
@@ -525,7 +528,8 @@ class LSTMTest(tf.test.TestCase):
           tf.placeholder(tf.float32, shape=(None, input_size))]
       cell = tf.nn.rnn_cell.LSTMCell(
           num_units, use_peepholes=True,
-          num_proj=num_proj, initializer=initializer)
+          num_proj=num_proj, initializer=initializer,
+          state_is_tuple=False)
       outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
       self.assertEqual(len(outputs), len(inputs))
 
@@ -546,7 +550,7 @@ class LSTMTest(tf.test.TestCase):
           tf.placeholder(tf.float32, shape=(None, input_size))]
       cell_notuple = tf.nn.rnn_cell.LSTMCell(
           num_units, use_peepholes=True,
-          num_proj=num_proj, initializer=initializer)
+          num_proj=num_proj, initializer=initializer, state_is_tuple=False)
       cell_tuple = tf.nn.rnn_cell.LSTMCell(
           num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer, state_is_tuple=True)
@@ -596,7 +600,8 @@ class LSTMTest(tf.test.TestCase):
           num_proj=num_proj,
           num_unit_shards=num_unit_shards,
           num_proj_shards=num_proj_shards,
-          initializer=initializer)
+          initializer=initializer,
+          state_is_tuple=False)
 
       outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
 
@@ -625,7 +630,8 @@ class LSTMTest(tf.test.TestCase):
           num_proj=num_proj,
           num_unit_shards=num_unit_shards,
           num_proj_shards=num_proj_shards,
-          initializer=initializer)
+          initializer=initializer,
+          state_is_tuple=False)
 
       with self.assertRaises(ValueError):
         tf.nn.rnn(cell, inputs, dtype=tf.float32)
@@ -649,7 +655,8 @@ class LSTMTest(tf.test.TestCase):
           num_proj=num_proj,
           num_unit_shards=num_unit_shards,
           num_proj_shards=num_proj_shards,
-          initializer=initializer)
+          initializer=initializer,
+          state_is_tuple=False)
 
       outputs, _ = tf.nn.rnn(
           cell, inputs, initial_state=cell.zero_state(batch_size, tf.float64))
@@ -681,11 +688,13 @@ class LSTMTest(tf.test.TestCase):
           use_peepholes=True,
           initializer=initializer,
           num_unit_shards=num_unit_shards,
-          num_proj_shards=num_proj_shards)
+          num_proj_shards=num_proj_shards,
+          state_is_tuple=False)
 
       cell_shard = tf.nn.rnn_cell.LSTMCell(
           num_units, use_peepholes=True,
-          initializer=initializer, num_proj=num_proj)
+          initializer=initializer, num_proj=num_proj,
+          state_is_tuple=False)
 
       with tf.variable_scope("noshard_scope"):
         outputs_noshard, state_noshard = tf.nn.rnn(
@@ -734,7 +743,8 @@ class LSTMTest(tf.test.TestCase):
           num_proj=num_proj,
           num_unit_shards=num_unit_shards,
           num_proj_shards=num_proj_shards,
-          initializer=initializer)
+          initializer=initializer,
+          state_is_tuple=False)
       dropout_cell = tf.nn.rnn_cell.DropoutWrapper(cell, 0.5, seed=0)
 
       outputs, state = tf.nn.rnn(
@@ -766,10 +776,12 @@ class LSTMTest(tf.test.TestCase):
           tf.placeholder(tf.float32, shape=(None, input_size))]
       cell = tf.nn.rnn_cell.LSTMCell(
           num_units, use_peepholes=True,
-          num_proj=num_proj, initializer=initializer)
+          num_proj=num_proj, initializer=initializer,
+          state_is_tuple=False)
       cell_d = tf.nn.rnn_cell.LSTMCell(
           num_units, use_peepholes=True,
-          num_proj=num_proj, initializer=initializer_d)
+          num_proj=num_proj, initializer=initializer_d,
+          state_is_tuple=False)
 
       with tf.variable_scope("share_scope"):
         outputs0, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
@@ -805,7 +817,8 @@ class LSTMTest(tf.test.TestCase):
           tf.placeholder(tf.float32, shape=(None, input_size))]
       cell = tf.nn.rnn_cell.LSTMCell(
           num_units, use_peepholes=True,
-          num_proj=num_proj, initializer=initializer)
+          num_proj=num_proj, initializer=initializer,
+          state_is_tuple=False)
 
       with tf.name_scope("scope0"):
         with tf.variable_scope("share_scope"):
@@ -947,7 +960,7 @@ class LSTMTest(tf.test.TestCase):
 
       cell = tf.nn.rnn_cell.LSTMCell(
           num_units, use_peepholes=True,
-          initializer=initializer, num_proj=num_proj)
+          initializer=initializer, num_proj=num_proj, state_is_tuple=False)
 
       with tf.variable_scope("dynamic_scope"):
         outputs_static, state_static = tf.nn.rnn(
@@ -1002,7 +1015,7 @@ class LSTMTest(tf.test.TestCase):
 
       cell = tf.nn.rnn_cell.LSTMCell(
           num_units, use_peepholes=True,
-          initializer=initializer, num_proj=num_proj)
+          initializer=initializer, num_proj=num_proj, state_is_tuple=False)
 
       with tf.variable_scope("dynamic_scope"):
         outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn(
@@ -1141,10 +1154,12 @@ class BidirectionalRNNTest(tf.test.TestCase):
     sequence_length = tf.placeholder(tf.int64) if use_sequence_length else None
     cell_fw = tf.nn.rnn_cell.LSTMCell(num_units,
                                       input_size,
-                                      initializer=initializer)
+                                      initializer=initializer,
+                                      state_is_tuple=False)
     cell_bw = tf.nn.rnn_cell.LSTMCell(num_units,
                                       input_size,
-                                      initializer=initializer)
+                                      initializer=initializer,
+                                      state_is_tuple=False)
     inputs = max_length * [
         tf.placeholder(
             tf.float32,
@@ -1889,7 +1904,8 @@ class StateSaverRNNTest(tf.test.TestCase):
     initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
     state_saver = TestStateSaver(batch_size, 2 * num_units)
     cell = tf.nn.rnn_cell.LSTMCell(
-        num_units, use_peepholes=False, initializer=initializer)
+        num_units, use_peepholes=False, initializer=initializer,
+        state_is_tuple=False)
     inputs = max_length * [
         tf.placeholder(tf.float32, shape=(batch_size, input_size))]
     return tf.nn.state_saving_rnn(
@@ -1906,7 +1922,8 @@ def _static_vs_dynamic_rnn_benchmark_static(inputs_list_t, sequence_length):
   (_, input_size) = inputs_list_t[0].get_shape().as_list()
   initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
   cell = tf.nn.rnn_cell.LSTMCell(
-      num_units=input_size, use_peepholes=True, initializer=initializer)
+      num_units=input_size, use_peepholes=True, initializer=initializer,
+      state_is_tuple=False)
   outputs, final_state = tf.nn.rnn(
       cell, inputs_list_t, sequence_length=sequence_length, dtype=tf.float32)
 
@@ -1920,7 +1937,8 @@ def _static_vs_dynamic_rnn_benchmark_dynamic(inputs_t, sequence_length):
   (unused_0, unused_1, input_size) = inputs_t.get_shape().as_list()
   initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
   cell = tf.nn.rnn_cell.LSTMCell(
-      num_units=input_size, use_peepholes=True, initializer=initializer)
+      num_units=input_size, use_peepholes=True, initializer=initializer,
+      state_is_tuple=False)
   outputs, final_state = tf.nn.dynamic_rnn(
       cell, inputs_t, sequence_length=sequence_length, dtype=tf.float32)
 
@@ -2023,7 +2041,8 @@ def _half_seq_len_vs_unroll_half_rnn_benchmark(inputs_list_t, sequence_length):
   (_, input_size) = inputs_list_t[0].get_shape().as_list()
   initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
   cell = tf.nn.rnn_cell.LSTMCell(
-      num_units=input_size, use_peepholes=True, initializer=initializer)
+      num_units=input_size, use_peepholes=True, initializer=initializer,
+      state_is_tuple=False)
   outputs, final_state = tf.nn.rnn(
       cell, inputs_list_t, sequence_length=sequence_length, dtype=tf.float32)
 
@@ -2132,7 +2151,8 @@ def _dynamic_rnn_swap_memory_benchmark(inputs_t, sequence_length,
   (unused_0, unused_1, input_size) = inputs_t.get_shape().as_list()
   initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
   cell = tf.nn.rnn_cell.LSTMCell(
-      num_units=input_size, use_peepholes=True, initializer=initializer)
+      num_units=input_size, use_peepholes=True, initializer=initializer,
+      state_is_tuple=False)
   outputs, final_state = tf.nn.dynamic_rnn(
       cell, inputs_t, sequence_length=sequence_length,
       swap_memory=swap_memory, dtype=tf.float32)
@@ -268,7 +268,7 @@ class BasicLSTMCell(RNNCell):
   """
 
   def __init__(self, num_units, forget_bias=1.0, input_size=None,
-               state_is_tuple=False, activation=tanh):
+               state_is_tuple=True, activation=tanh):
     """Initialize the basic LSTM cell.
 
     Args:
@@ -276,8 +276,8 @@ class BasicLSTMCell(RNNCell):
       forget_bias: float, The bias added to forget gates (see above).
       input_size: Deprecated and unused.
       state_is_tuple: If True, accepted and returned states are 2-tuples of
-        the `c_state` and `m_state`. By default (False), they are concatenated
-        along the column axis. This default behavior will soon be deprecated.
+        the `c_state` and `m_state`. If False, they are concatenated
+        along the column axis. The latter behavior will soon be deprecated.
       activation: Activation function of the inner states.
     """
     if not state_is_tuple:
@@ -385,7 +385,7 @@ class LSTMCell(RNNCell):
                use_peepholes=False, cell_clip=None,
                initializer=None, num_proj=None, proj_clip=None,
                num_unit_shards=1, num_proj_shards=1,
-               forget_bias=1.0, state_is_tuple=False,
+               forget_bias=1.0, state_is_tuple=True,
                activation=tanh):
     """Initialize the parameters for an LSTM cell.
 
@@ -410,8 +410,8 @@ class LSTMCell(RNNCell):
         in order to reduce the scale of forgetting at the beginning of
         the training.
       state_is_tuple: If True, accepted and returned states are 2-tuples of
-        the `c_state` and `m_state`. By default (False), they are concatenated
-        along the column axis. This default behavior will soon be deprecated.
+        the `c_state` and `m_state`. If False, they are concatenated
+        along the column axis. This latter behavior will soon be deprecated.
       activation: Activation function of the inner states.
     """
     if not state_is_tuple:
@@ -757,14 +757,15 @@ class EmbeddingWrapper(RNNCell):
 class MultiRNNCell(RNNCell):
   """RNN cell composed sequentially of multiple simple cells."""
 
-  def __init__(self, cells, state_is_tuple=False):
+  def __init__(self, cells, state_is_tuple=True):
     """Create a RNN cell composed sequentially of a number of RNNCells.
 
     Args:
       cells: list of RNNCells that will be composed in this order.
       state_is_tuple: If True, accepted and returned states are n-tuples, where
-        `n = len(cells)`. By default (False), the states are all
-        concatenated along the column axis.
+        `n = len(cells)`. If False, the states are all
+        concatenated along the column axis. This latter behavior will soon be
+        deprecated.
 
     Raises:
       ValueError: if cells is empty (not allowed), or at least one of the cells
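To summarize the docstring changes above: with the new default, every accepted and returned state is a 2-tuple of `c_state` and `m_state` (or an n-tuple of per-cell states for MultiRNNCell), so downstream code unpacks instead of slicing. A short illustrative sketch, not part of the commit:

```python
cell = tf.nn.rnn_cell.LSTMCell(num_units=3)   # state_is_tuple=True now
state = (tf.zeros([1, 3]), tf.zeros([1, 3]))  # (c_state, m_state)
output, state = cell(tf.zeros([1, 5]), state)
c_state, m_state = state                      # unpack instead of slicing
```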