From 6f3cc9d368a17646f5838e36be3b1c25bf4534fe Mon Sep 17 00:00:00 2001
From: Yong Tang
Date: Sat, 28 Apr 2018 12:06:15 -0700
Subject: [PATCH] Pass dtype to constructor in LSTMCell (#18178)

* Use float32 in case the dtype is not set in the constructor

Signed-off-by: Yong Tang

* Add test case for 16228.

Signed-off-by: Yong Tang

* Add test case where dtype is passed explicitly.

Signed-off-by: Yong Tang

* Fix pylint issue

Signed-off-by: Yong Tang

* Replace strings with objects to address review feedback.

Signed-off-by: Yong Tang
---
 .../rnn/python/kernel_tests/core_rnn_test.py     | 15 +++++++++++++++
 tensorflow/python/ops/rnn_cell_impl.py           |  6 +++++-
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
index de5df912921..ba4933ddf79 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -307,6 +307,21 @@ class LSTMTest(test.TestCase):
     self._seed = 23489
     np.random.seed(self._seed)
 
+  def testDType(self):
+    # Test case for GitHub issue 16228
+    # Not passing dtype in constructor results in default float32
+    lstm = rnn_cell.LSTMCell(10)
+    input_tensor = array_ops.ones([10, 50])
+    lstm.build(input_tensor.get_shape())
+    self.assertEqual(lstm._bias.dtype, dtypes.float32_ref)
+
+    # Explicitly pass dtype in constructor
+    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+      lstm = rnn_cell.LSTMCell(10, dtype=dtype)
+      input_tensor = array_ops.ones([10, 50])
+      lstm.build(input_tensor.get_shape())
+      self.assertEqual(lstm._bias.dtype, dtype._as_ref)
+
   def testNoProjNoSharding(self):
     num_units = 3
     input_size = 5
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 86dc053c0fb..67f753485b8 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -785,10 +785,14 @@ class LSTMCell(LayerRNNCell):
         shape=[input_depth + h_depth, 4 * self._num_units],
         initializer=self._initializer,
         partitioner=maybe_partitioner)
+    if self.dtype is None:
+      initializer = init_ops.zeros_initializer
+    else:
+      initializer = init_ops.zeros_initializer(dtype=self.dtype)
     self._bias = self.add_variable(
         _BIAS_VARIABLE_NAME,
         shape=[4 * self._num_units],
-        initializer=init_ops.zeros_initializer(dtype=self.dtype))
+        initializer=initializer)
     if self._use_peepholes:
       self._w_f_diag = self.add_variable("w_f_diag", shape=[self._num_units],
                                          initializer=self._initializer)