Added 2 unit tests for BasicLSTMCell to check ValueError. (#9693)
* Added 2 unit tests for BasicLSTMCell to check ValueError. * spelling error * change assertRaisesRegexp => assertRaises * some lint fixes
This commit is contained in:
parent
4be052a5fc
commit
d5956e3e6a
@ -194,6 +194,44 @@ class RNNCellTest(test.TestCase):
|
||||
m.name: 0.1 * np.ones([1, 4])})
|
||||
self.assertEqual(len(res), 2)
|
||||
|
||||
def testBasicLSTMCellDimension0Error(self):
|
||||
"""Tests that dimension 0 in both(x and m) shape must be equal."""
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
num_units = 2
|
||||
state_size = num_units * 2
|
||||
batch_size = 3
|
||||
input_size = 4
|
||||
x = array_ops.zeros([batch_size, input_size])
|
||||
m = array_ops.zeros([batch_size - 1, state_size])
|
||||
with self.assertRaises(ValueError):
|
||||
g, out_m = core_rnn_cell_impl.BasicLSTMCell(
|
||||
num_units, state_is_tuple=False)(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
sess.run([g, out_m],
|
||||
{x.name: 1 * np.ones([batch_size, input_size]),
|
||||
m.name: 0.1 * np.ones([batch_size - 1, state_size])})
|
||||
|
||||
def testBasicLSTMCellStateSizeError(self):
|
||||
"""Tests that state_size must be num_units * 2."""
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
num_units = 2
|
||||
state_size = num_units * 3 # state_size must be num_units * 2
|
||||
batch_size = 3
|
||||
input_size = 4
|
||||
x = array_ops.zeros([batch_size, input_size])
|
||||
m = array_ops.zeros([batch_size, state_size])
|
||||
with self.assertRaises(ValueError):
|
||||
g, out_m = core_rnn_cell_impl.BasicLSTMCell(
|
||||
num_units, state_is_tuple=False)(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
sess.run([g, out_m],
|
||||
{x.name: 1 * np.ones([batch_size, input_size]),
|
||||
m.name: 0.1 * np.ones([batch_size, state_size])})
|
||||
|
||||
def testBasicLSTMCellStateTupleType(self):
|
||||
with self.test_session():
|
||||
with variable_scope.variable_scope(
|
||||
|
Loading…
Reference in New Issue
Block a user