diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index f4589e3d9e1..89ad0fcd753 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -194,6 +194,44 @@ class RNNCellTest(test.TestCase): m.name: 0.1 * np.ones([1, 4])}) self.assertEqual(len(res), 2) + def testBasicLSTMCellDimension0Error(self): + """Tests that dimension 0 in both(x and m) shape must be equal.""" + with self.test_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + num_units = 2 + state_size = num_units * 2 + batch_size = 3 + input_size = 4 + x = array_ops.zeros([batch_size, input_size]) + m = array_ops.zeros([batch_size - 1, state_size]) + with self.assertRaises(ValueError): + g, out_m = core_rnn_cell_impl.BasicLSTMCell( + num_units, state_is_tuple=False)(x, m) + sess.run([variables_lib.global_variables_initializer()]) + sess.run([g, out_m], + {x.name: 1 * np.ones([batch_size, input_size]), + m.name: 0.1 * np.ones([batch_size - 1, state_size])}) + + def testBasicLSTMCellStateSizeError(self): + """Tests that state_size must be num_units * 2.""" + with self.test_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + num_units = 2 + state_size = num_units * 3 # state_size must be num_units * 2 + batch_size = 3 + input_size = 4 + x = array_ops.zeros([batch_size, input_size]) + m = array_ops.zeros([batch_size, state_size]) + with self.assertRaises(ValueError): + g, out_m = core_rnn_cell_impl.BasicLSTMCell( + num_units, state_is_tuple=False)(x, m) + sess.run([variables_lib.global_variables_initializer()]) + sess.run([g, out_m], + {x.name: 1 * np.ones([batch_size, input_size]), + m.name: 0.1 * np.ones([batch_size, state_size])}) + def testBasicLSTMCellStateTupleType(self): with self.test_session(): with variable_scope.variable_scope(