From d5956e3e6a30ec7c81a36d0ab82656e865716f84 Mon Sep 17 00:00:00 2001
From: Chris Hoyean Song <sjhshy@gmail.com>
Date: Sun, 7 May 2017 03:03:33 +0900
Subject: [PATCH] Added 2 unit tests for BasicLSTMCell to check ValueError.
 (#9693)

* Added 2 unit tests for BasicLSTMCell to check ValueError.

* spelling error

* change assertRaisesRegexp => assertRaises

* some lint fixes
---
 .../python/kernel_tests/core_rnn_cell_test.py | 38 +++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index f4589e3d9e1..89ad0fcd753 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -194,6 +194,44 @@ class RNNCellTest(test.TestCase):
              m.name: 0.1 * np.ones([1, 4])})
         self.assertEqual(len(res), 2)
 
+  def testBasicLSTMCellDimension0Error(self):
+    """Tests that dimension 0 in both(x and m) shape must be equal."""
+    with self.test_session() as sess:
+      with variable_scope.variable_scope(
+          "root", initializer=init_ops.constant_initializer(0.5)):
+        num_units = 2
+        state_size = num_units * 2
+        batch_size = 3
+        input_size = 4
+        x = array_ops.zeros([batch_size, input_size])
+        m = array_ops.zeros([batch_size - 1, state_size])
+        with self.assertRaises(ValueError):
+          g, out_m = core_rnn_cell_impl.BasicLSTMCell(
+              num_units, state_is_tuple=False)(x, m)
+          sess.run([variables_lib.global_variables_initializer()])
+          sess.run([g, out_m],
+                   {x.name: 1 * np.ones([batch_size, input_size]),
+               m.name: 0.1 * np.ones([batch_size - 1, state_size])})
+
+  def testBasicLSTMCellStateSizeError(self):
+    """Tests that state_size must be num_units * 2."""
+    with self.test_session() as sess:
+      with variable_scope.variable_scope(
+          "root", initializer=init_ops.constant_initializer(0.5)):
+        num_units = 2
+        state_size = num_units * 3 # state_size must be num_units * 2
+        batch_size = 3
+        input_size = 4
+        x = array_ops.zeros([batch_size, input_size])
+        m = array_ops.zeros([batch_size, state_size])
+        with self.assertRaises(ValueError):
+          g, out_m = core_rnn_cell_impl.BasicLSTMCell(
+              num_units, state_is_tuple=False)(x, m)
+          sess.run([variables_lib.global_variables_initializer()])
+          sess.run([g, out_m],
+                   {x.name: 1 * np.ones([batch_size, input_size]),
+                    m.name: 0.1 * np.ones([batch_size, state_size])})
+
   def testBasicLSTMCellStateTupleType(self):
     with self.test_session():
       with variable_scope.variable_scope(