Pass dtype to constructor in LSTMCell (#18178)

* Use float32 in case the dtype is not set in the constructor

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add test case for 16228.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add test case where dype is passed explicitly.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Fix pylint issue

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Replace strings to objects to address review feedback.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2018-04-28 12:06:15 -07:00 committed by drpngx
parent 17cb3cdd30
commit 6f3cc9d368
2 changed files with 20 additions and 1 deletions

View File

@ -307,6 +307,21 @@ class LSTMTest(test.TestCase):
self._seed = 23489 self._seed = 23489
np.random.seed(self._seed) np.random.seed(self._seed)
def testDType(self):
# Test case for GitHub issue 16228
# Not passing dtype in constructor results in default float32
lstm = rnn_cell.LSTMCell(10)
input_tensor = array_ops.ones([10, 50])
lstm.build(input_tensor.get_shape())
self.assertEqual(lstm._bias.dtype, dtypes.float32_ref)
# Explicitly pass dtype in constructor
for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
lstm = rnn_cell.LSTMCell(10, dtype=dtype)
input_tensor = array_ops.ones([10, 50])
lstm.build(input_tensor.get_shape())
self.assertEqual(lstm._bias.dtype, dtype._as_ref)
def testNoProjNoSharding(self): def testNoProjNoSharding(self):
num_units = 3 num_units = 3
input_size = 5 input_size = 5

View File

@ -785,10 +785,14 @@ class LSTMCell(LayerRNNCell):
shape=[input_depth + h_depth, 4 * self._num_units], shape=[input_depth + h_depth, 4 * self._num_units],
initializer=self._initializer, initializer=self._initializer,
partitioner=maybe_partitioner) partitioner=maybe_partitioner)
if self.dtype is None:
initializer = init_ops.zeros_initializer
else:
initializer = init_ops.zeros_initializer(dtype=self.dtype)
self._bias = self.add_variable( self._bias = self.add_variable(
_BIAS_VARIABLE_NAME, _BIAS_VARIABLE_NAME,
shape=[4 * self._num_units], shape=[4 * self._num_units],
initializer=init_ops.zeros_initializer(dtype=self.dtype)) initializer=initializer)
if self._use_peepholes: if self._use_peepholes:
self._w_f_diag = self.add_variable("w_f_diag", shape=[self._num_units], self._w_f_diag = self.add_variable("w_f_diag", shape=[self._num_units],
initializer=self._initializer) initializer=self._initializer)