Added 2 unit tests for BasicLSTMCell to check ValueError. (#9693)
* Added 2 unit tests for BasicLSTMCell to check ValueError. * spelling error * change assertRaisesRegexp => assertRaises * some lint fixes
This commit is contained in:
parent
4be052a5fc
commit
d5956e3e6a
@ -194,6 +194,44 @@ class RNNCellTest(test.TestCase):
|
|||||||
m.name: 0.1 * np.ones([1, 4])})
|
m.name: 0.1 * np.ones([1, 4])})
|
||||||
self.assertEqual(len(res), 2)
|
self.assertEqual(len(res), 2)
|
||||||
|
|
||||||
|
def testBasicLSTMCellDimension0Error(self):
|
||||||
|
"""Tests that dimension 0 in both(x and m) shape must be equal."""
|
||||||
|
with self.test_session() as sess:
|
||||||
|
with variable_scope.variable_scope(
|
||||||
|
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||||
|
num_units = 2
|
||||||
|
state_size = num_units * 2
|
||||||
|
batch_size = 3
|
||||||
|
input_size = 4
|
||||||
|
x = array_ops.zeros([batch_size, input_size])
|
||||||
|
m = array_ops.zeros([batch_size - 1, state_size])
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
g, out_m = core_rnn_cell_impl.BasicLSTMCell(
|
||||||
|
num_units, state_is_tuple=False)(x, m)
|
||||||
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
|
sess.run([g, out_m],
|
||||||
|
{x.name: 1 * np.ones([batch_size, input_size]),
|
||||||
|
m.name: 0.1 * np.ones([batch_size - 1, state_size])})
|
||||||
|
|
||||||
|
def testBasicLSTMCellStateSizeError(self):
|
||||||
|
"""Tests that state_size must be num_units * 2."""
|
||||||
|
with self.test_session() as sess:
|
||||||
|
with variable_scope.variable_scope(
|
||||||
|
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||||
|
num_units = 2
|
||||||
|
state_size = num_units * 3 # state_size must be num_units * 2
|
||||||
|
batch_size = 3
|
||||||
|
input_size = 4
|
||||||
|
x = array_ops.zeros([batch_size, input_size])
|
||||||
|
m = array_ops.zeros([batch_size, state_size])
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
g, out_m = core_rnn_cell_impl.BasicLSTMCell(
|
||||||
|
num_units, state_is_tuple=False)(x, m)
|
||||||
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
|
sess.run([g, out_m],
|
||||||
|
{x.name: 1 * np.ones([batch_size, input_size]),
|
||||||
|
m.name: 0.1 * np.ones([batch_size, state_size])})
|
||||||
|
|
||||||
def testBasicLSTMCellStateTupleType(self):
|
def testBasicLSTMCellStateTupleType(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
with variable_scope.variable_scope(
|
with variable_scope.variable_scope(
|
||||||
|
Loading…
Reference in New Issue
Block a user