STT-tensorflow/tensorflow/python/kernel_tests/rnn_cell_test.py
Eugene Brevdo ae5c66e3c2 Add support for arbitrarily nested tuples for RNN state.
Also fixed a bug in the RNN unit tests.
Change: 123150781
2016-05-24 15:23:33 -07:00

324 lines
14 KiB
Python

# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for RNN cells."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import numpy as np
import tensorflow as tf
# TODO(ebrevdo): Remove once _linear is fully deprecated.
# pylint: disable=protected-access
from tensorflow.python.ops.rnn_cell import _linear as linear
# pylint: enable=protected-access
class RNNCellTest(tf.test.TestCase):
  """Smoke tests for `linear` and the cells/wrappers in `tf.nn.rnn_cell`.

  Each test builds a single-step graph from zero-valued tensors under a
  variable scope with a constant initializer, then feeds the actual input and
  state values through `sess.run`, keyed by tensor name.  Most of the expected
  numbers are regression values, not hand-derived results.
  """

  def testLinear(self):
    """Checks the output of `linear` and its variable-sharing behavior."""
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(1.0)):
        x = tf.zeros([1, 2])
        lin_out = linear([x], 2, False)
        sess.run([tf.initialize_all_variables()])
        res = sess.run([lin_out], {x.name: np.array([[1., 2.]])})
        # With the constant 1.0 initializer the expected outputs are 1 + 2.
        self.assertAllClose(res[0], [[3.0, 3.0]])

        # Checks prevent you from accidentally creating a shared function.
        with self.assertRaises(ValueError):
          linear([x], 2, False)

        # But you can create a new one in a new scope and share the variables.
        with tf.variable_scope("l1") as new_scope:
          l1 = linear([x], 2, False)
        with tf.variable_scope(new_scope, reuse=True):
          linear([l1], 2, False)
        # The reusing call must not add variables beyond the two made above.
        self.assertEqual(len(tf.trainable_variables()), 2)

  def testBasicRNNCell(self):
    """Runs one BasicRNNCell step and checks the output shape only."""
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        x = tf.zeros([1, 2])
        m = tf.zeros([1, 2])
        g, _ = tf.nn.rnn_cell.BasicRNNCell(2)(x, m)
        sess.run([tf.initialize_all_variables()])
        res = sess.run([g], {x.name: np.array([[1., 1.]]),
                             m.name: np.array([[0.1, 0.1]])})
        self.assertEqual(res[0].shape, (1, 2))

  def testGRUCell(self):
    """Runs one GRUCell step for input widths of 2 and 3 (num_units is 2)."""
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        x = tf.zeros([1, 2])
        m = tf.zeros([1, 2])
        g, _ = tf.nn.rnn_cell.GRUCell(2)(x, m)
        sess.run([tf.initialize_all_variables()])
        res = sess.run([g], {x.name: np.array([[1., 1.]]),
                             m.name: np.array([[0.1, 0.1]])})
        # Smoke test
        self.assertAllClose(res[0], [[0.175991, 0.175991]])
      with tf.variable_scope("other", initializer=tf.constant_initializer(0.5)):
        x = tf.zeros([1, 3])  # Test GRUCell with input_size != num_units.
        m = tf.zeros([1, 2])
        g, _ = tf.nn.rnn_cell.GRUCell(2)(x, m)
        sess.run([tf.initialize_all_variables()])
        res = sess.run([g], {x.name: np.array([[1., 1., 1.]]),
                             m.name: np.array([[0.1, 0.1]])})
        # Smoke test
        self.assertAllClose(res[0], [[0.156736, 0.156736]])

  def testBasicLSTMCell(self):
    """Runs two stacked BasicLSTMCells with a flat state, then a single one."""
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        x = tf.zeros([1, 2])
        # Flat state of width 8 covering both stacked cells.
        m = tf.zeros([1, 8])
        g, out_m = tf.nn.rnn_cell.MultiRNNCell(
            [tf.nn.rnn_cell.BasicLSTMCell(2)] * 2)(x, m)
        sess.run([tf.initialize_all_variables()])
        res = sess.run([g, out_m], {x.name: np.array([[1., 1.]]),
                                    m.name: 0.1 * np.ones([1, 8])})
        self.assertEqual(len(res), 2)
        # The numbers in results were not calculated, this is just a smoke test.
        self.assertAllClose(res[0], [[0.24024698, 0.24024698]])
        expected_mem = np.array([[0.68967271, 0.68967271,
                                  0.44848421, 0.44848421,
                                  0.39897051, 0.39897051,
                                  0.24024698, 0.24024698]])
        self.assertAllClose(res[1], expected_mem)
      with tf.variable_scope("other", initializer=tf.constant_initializer(0.5)):
        x = tf.zeros([1, 3])  # Test BasicLSTMCell with input_size != num_units.
        m = tf.zeros([1, 4])
        g, out_m = tf.nn.rnn_cell.BasicLSTMCell(2)(x, m)
        sess.run([tf.initialize_all_variables()])
        res = sess.run([g, out_m], {x.name: np.array([[1., 1., 1.]]),
                                    m.name: 0.1 * np.ones([1, 4])})
        # Only verifies that the step runs and yields both fetches.
        self.assertEqual(len(res), 2)

  def testBasicLSTMCellWithStateTuple(self):
    """Same stack as testBasicLSTMCell, with the state passed as a tuple."""
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        x = tf.zeros([1, 2])
        m0 = tf.zeros([1, 4])
        m1 = tf.zeros([1, 4])
        cell = tf.nn.rnn_cell.MultiRNNCell(
            [tf.nn.rnn_cell.BasicLSTMCell(2)] * 2, state_is_tuple=True)
        g, (out_m0, out_m1) = cell(x, (m0, m1))
        sess.run([tf.initialize_all_variables()])
        res = sess.run([g, out_m0, out_m1],
                       {x.name: np.array([[1., 1.]]),
                        m0.name: 0.1 * np.ones([1, 4]),
                        m1.name: 0.1 * np.ones([1, 4])})
        self.assertEqual(len(res), 3)
        # The numbers in results were not calculated, this is just a smoke test.
        # Note, however, these values should match the original
        # version having state_is_tuple=False.
        self.assertAllClose(res[0], [[0.24024698, 0.24024698]])
        expected_mem0 = np.array([[0.68967271, 0.68967271,
                                   0.44848421, 0.44848421]])
        expected_mem1 = np.array([[0.39897051, 0.39897051,
                                   0.24024698, 0.24024698]])
        self.assertAllClose(res[1], expected_mem0)
        self.assertAllClose(res[2], expected_mem1)

  def testLSTMCell(self):
    """Checks LSTMCell output/state shapes and that batch rows differ."""
    with self.test_session() as sess:
      num_units = 8
      num_proj = 6
      state_size = num_units + num_proj
      batch_size = 3
      input_size = 2
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        x = tf.zeros([batch_size, input_size])
        m = tf.zeros([batch_size, state_size])
        output, state = tf.nn.rnn_cell.LSTMCell(
            num_units=num_units, num_proj=num_proj, forget_bias=1.0)(x, m)
        sess.run([tf.initialize_all_variables()])
        res = sess.run([output, state],
                       {x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]),
                        m.name: 0.1 * np.ones((batch_size, state_size))})
        self.assertEqual(len(res), 2)
        # The numbers in results were not calculated, this is mostly just a
        # smoke test.
        self.assertEqual(res[0].shape, (batch_size, num_proj))
        self.assertEqual(res[1].shape, (batch_size, state_size))
        # Different inputs so different outputs and states.  assertGreater
        # reports the offending norm when a pair of rows is too close.
        for i in range(1, batch_size):
          self.assertGreater(
              float(np.linalg.norm(res[0][0, :] - res[0][i, :])), 1e-6)
          self.assertGreater(
              float(np.linalg.norm(res[1][0, :] - res[1][i, :])), 1e-6)

  def testOutputProjectionWrapper(self):
    """Wraps a GRUCell(3) so its output is projected to width 2."""
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        x = tf.zeros([1, 3])
        m = tf.zeros([1, 3])
        cell = tf.nn.rnn_cell.OutputProjectionWrapper(
            tf.nn.rnn_cell.GRUCell(3), 2)
        g, new_m = cell(x, m)
        sess.run([tf.initialize_all_variables()])
        res = sess.run([g, new_m], {x.name: np.array([[1., 1., 1.]]),
                                    m.name: np.array([[0.1, 0.1, 0.1]])})
        # The state keeps the wrapped cell's width; only the output is narrower.
        self.assertEqual(res[1].shape, (1, 3))
        # The numbers in results were not calculated, this is just a smoke test.
        self.assertAllClose(res[0], [[0.231907, 0.231907]])

  def testInputProjectionWrapper(self):
    """Feeds a width-2 input to a GRUCell(3) through an input projection."""
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        x = tf.zeros([1, 2])
        m = tf.zeros([1, 3])
        cell = tf.nn.rnn_cell.InputProjectionWrapper(
            tf.nn.rnn_cell.GRUCell(3), num_proj=3)
        g, new_m = cell(x, m)
        sess.run([tf.initialize_all_variables()])
        res = sess.run([g, new_m], {x.name: np.array([[1., 1.]]),
                                    m.name: np.array([[0.1, 0.1, 0.1]])})
        self.assertEqual(res[1].shape, (1, 3))
        # The numbers in results were not calculated, this is just a smoke test.
        self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]])

  def testDropoutWrapper(self):
    """Runs a GRUCell(3) inside a DropoutWrapper with both keep values at 1."""
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        x = tf.zeros([1, 3])
        m = tf.zeros([1, 3])
        # Scalar tensor equal to 1, passed for both keep arguments.
        keep = tf.zeros([]) + 1
        g, new_m = tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.GRUCell(3),
                                                 keep, keep)(x, m)
        sess.run([tf.initialize_all_variables()])
        res = sess.run([g, new_m], {x.name: np.array([[1., 1., 1.]]),
                                    m.name: np.array([[0.1, 0.1, 0.1]])})
        self.assertEqual(res[1].shape, (1, 3))
        # The numbers in results were not calculated, this is just a smoke test.
        self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]])

  def testEmbeddingWrapper(self):
    """Feeds an int32 id to a GRUCell(2) through an EmbeddingWrapper."""
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        x = tf.zeros([1, 1], dtype=tf.int32)
        m = tf.zeros([1, 2])
        g, new_m = tf.nn.rnn_cell.EmbeddingWrapper(
            tf.nn.rnn_cell.GRUCell(2),
            embedding_classes=3, embedding_size=2)(x, m)
        sess.run([tf.initialize_all_variables()])
        res = sess.run([g, new_m], {x.name: np.array([[1]]),
                                    m.name: np.array([[0.1, 0.1]])})
        self.assertEqual(res[1].shape, (1, 2))
        # The numbers in results were not calculated, this is just a smoke test.
        self.assertAllClose(res[0], [[0.17139, 0.17139]])

  def testMultiRNNCell(self):
    """Runs two stacked GRUCells with a flat state and checks the new state."""
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        x = tf.zeros([1, 2])
        m = tf.zeros([1, 4])
        _, ml = tf.nn.rnn_cell.MultiRNNCell(
            [tf.nn.rnn_cell.GRUCell(2)] * 2)(x, m)
        sess.run([tf.initialize_all_variables()])
        res = sess.run(ml, {x.name: np.array([[1., 1.]]),
                            m.name: np.array([[0.1, 0.1, 0.1, 0.1]])})
        # The numbers in results were not calculated, this is just a smoke test.
        self.assertAllClose(res, [[0.175991, 0.175991,
                                   0.13248, 0.13248]])

  def testMultiRNNCellWithStateTuple(self):
    """Same stack as testMultiRNNCell with a tuple state; flat state fails."""
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        x = tf.zeros([1, 2])
        m_bad = tf.zeros([1, 4])
        m_good = (tf.zeros([1, 2]), tf.zeros([1, 2]))

        # Test incorrectness of state
        with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
          tf.nn.rnn_cell.MultiRNNCell(
              [tf.nn.rnn_cell.GRUCell(2)] * 2, state_is_tuple=True)(x, m_bad)

        _, ml = tf.nn.rnn_cell.MultiRNNCell(
            [tf.nn.rnn_cell.GRUCell(2)] * 2, state_is_tuple=True)(x, m_good)
        sess.run([tf.initialize_all_variables()])
        res = sess.run(ml, {x.name: np.array([[1., 1.]]),
                            m_good[0].name: np.array([[0.1, 0.1]]),
                            m_good[1].name: np.array([[0.1, 0.1]])})
        # The numbers in results were not calculated, this is just a
        # smoke test.  However, these numbers should match those of
        # the test testMultiRNNCell.
        self.assertAllClose(res[0], [[0.175991, 0.175991]])
        self.assertAllClose(res[1], [[0.13248, 0.13248]])
class SlimRNNCellTest(tf.test.TestCase):
  """Tests `tf.nn.rnn_cell.SlimRNNCell` wrapped around `basic_rnn_cell`."""

  def testBasicRNNCell(self):
    """Runs one step of the wrapped cell function and checks output shape."""
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        step_input = tf.zeros([1, 2])
        prev_state = tf.zeros([1, 2])
        cell_fn = functools.partial(basic_rnn_cell, num_units=2)
        wrapped_cell = tf.nn.rnn_cell.SlimRNNCell(cell_fn)
        step_output, _ = wrapped_cell(step_input, prev_state)
        sess.run([tf.initialize_all_variables()])
        # The zero tensors above are only feed targets; real values go here.
        feed = {
            step_input.name: np.array([[1., 1.]]),
            prev_state.name: np.array([[0.1, 0.1]]),
        }
        fetched = sess.run([step_output], feed)
        self.assertEqual(fetched[0].shape, (1, 2))

  def testBasicRNNCellMatch(self):
    """Compares a SlimRNNCell step against a BasicRNNCell step."""
    batch_size = 32
    input_size = 100
    num_units = 10
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        inputs = tf.random_uniform((batch_size, input_size))
        # Passing state=None asks basic_rnn_cell for its initial tensors.
        _, initial_state = basic_rnn_cell(inputs, None, num_units)

        cell_fn = functools.partial(basic_rnn_cell, num_units=num_units)
        slim_cell = tf.nn.rnn_cell.SlimRNNCell(cell_fn)
        slim_outputs, slim_state = slim_cell(inputs, initial_state)

        reference_cell = tf.nn.rnn_cell.BasicRNNCell(num_units)
        outputs, state = reference_cell(inputs, initial_state)

        # Static shapes must agree before anything is evaluated.
        self.assertEqual(slim_outputs.get_shape(), outputs.get_shape())
        self.assertEqual(slim_state.get_shape(), state.get_shape())

        sess.run([tf.initialize_all_variables()])
        # Fetch all four tensors in a single run call.
        slim_out_val, slim_state_val, out_val, state_val = sess.run(
            [slim_outputs, slim_state, outputs, state])
        self.assertAllClose(slim_out_val, out_val)
        self.assertAllClose(slim_state_val, state_val)
def basic_rnn_cell(inputs, state, num_units, scope=None):
  """Functional RNN cell step used by the SlimRNNCell tests.

  Args:
    inputs: Input tensor, or None.
    state: Previous state tensor.  If None, no step is computed and zero
      initial tensors are returned instead.
    num_units: Width of the output and state.
    scope: Optional scope passed to `tf.variable_op_scope`; the default name
      is "BasicRNNCell".

  Returns:
    A pair `(output, state)`.  When `state` is None both are zero tensors of
    shape `[batch_size, num_units]`, where `batch_size` and the dtype come
    from `inputs` (or are 0 and `tf.float32` when `inputs` is None).
    Otherwise both entries are the same tensor:
    `tanh(linear([inputs, state], num_units, True))`.
  """
  if state is not None:
    with tf.variable_op_scope([inputs, state], scope, "BasicRNNCell"):
      activation = tf.tanh(linear([inputs, state], num_units, True))
    return activation, activation

  # No state supplied: build zero-valued initial output and state.
  if inputs is None:
    batch_size, dtype = 0, tf.float32
  else:
    batch_size, dtype = inputs.get_shape()[0], inputs.dtype

  init_output = tf.zeros(tf.pack([batch_size, num_units]), dtype=dtype)
  init_state = tf.zeros(tf.pack([batch_size, num_units]), dtype=dtype)
  # Set the static shape explicitly on each tensor.
  for zero_tensor in (init_output, init_state):
    zero_tensor.set_shape([batch_size, num_units])
  return init_output, init_state
# Hand control to TensorFlow's test runner when executed as a script.
if __name__ == "__main__":
  tf.test.main()