Added 2 unit tests for BasicLSTMCell to check ValueError. (#9693)

* Added 2 unit tests for BasicLSTMCell to check ValueError.

* spelling error

* change assertRaisesRegexp => assertRaises

* some lint fixes
This commit is contained in:
Chris Hoyean Song 2017-05-07 03:03:33 +09:00 committed by Vijay Vasudevan
parent 4be052a5fc
commit d5956e3e6a

View File

@ -194,6 +194,44 @@ class RNNCellTest(test.TestCase):
m.name: 0.1 * np.ones([1, 4])})
self.assertEqual(len(res), 2)
def testBasicLSTMCellDimension0Error(self):
"""Tests that dimension 0 in both(x and m) shape must be equal."""
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
num_units = 2
state_size = num_units * 2
batch_size = 3
input_size = 4
x = array_ops.zeros([batch_size, input_size])
m = array_ops.zeros([batch_size - 1, state_size])
with self.assertRaises(ValueError):
g, out_m = core_rnn_cell_impl.BasicLSTMCell(
num_units, state_is_tuple=False)(x, m)
sess.run([variables_lib.global_variables_initializer()])
sess.run([g, out_m],
{x.name: 1 * np.ones([batch_size, input_size]),
m.name: 0.1 * np.ones([batch_size - 1, state_size])})
def testBasicLSTMCellStateSizeError(self):
"""Tests that state_size must be num_units * 2."""
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
num_units = 2
state_size = num_units * 3 # state_size must be num_units * 2
batch_size = 3
input_size = 4
x = array_ops.zeros([batch_size, input_size])
m = array_ops.zeros([batch_size, state_size])
with self.assertRaises(ValueError):
g, out_m = core_rnn_cell_impl.BasicLSTMCell(
num_units, state_is_tuple=False)(x, m)
sess.run([variables_lib.global_variables_initializer()])
sess.run([g, out_m],
{x.name: 1 * np.ones([batch_size, input_size]),
m.name: 0.1 * np.ones([batch_size, state_size])})
def testBasicLSTMCellStateTupleType(self):
with self.test_session():
with variable_scope.variable_scope(