Pass dtype to constructor in LSTMCell (#18178)
* Use float32 in case the dtype is not set in the constructor Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add test case for 16228. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add test case where dype is passed explicitly. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Fix pylint issue Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Replace strings to objects to address review feedback. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
17cb3cdd30
commit
6f3cc9d368
@ -307,6 +307,21 @@ class LSTMTest(test.TestCase):
|
|||||||
self._seed = 23489
|
self._seed = 23489
|
||||||
np.random.seed(self._seed)
|
np.random.seed(self._seed)
|
||||||
|
|
||||||
|
def testDType(self):
|
||||||
|
# Test case for GitHub issue 16228
|
||||||
|
# Not passing dtype in constructor results in default float32
|
||||||
|
lstm = rnn_cell.LSTMCell(10)
|
||||||
|
input_tensor = array_ops.ones([10, 50])
|
||||||
|
lstm.build(input_tensor.get_shape())
|
||||||
|
self.assertEqual(lstm._bias.dtype, dtypes.float32_ref)
|
||||||
|
|
||||||
|
# Explicitly pass dtype in constructor
|
||||||
|
for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
|
||||||
|
lstm = rnn_cell.LSTMCell(10, dtype=dtype)
|
||||||
|
input_tensor = array_ops.ones([10, 50])
|
||||||
|
lstm.build(input_tensor.get_shape())
|
||||||
|
self.assertEqual(lstm._bias.dtype, dtype._as_ref)
|
||||||
|
|
||||||
def testNoProjNoSharding(self):
|
def testNoProjNoSharding(self):
|
||||||
num_units = 3
|
num_units = 3
|
||||||
input_size = 5
|
input_size = 5
|
||||||
|
@ -785,10 +785,14 @@ class LSTMCell(LayerRNNCell):
|
|||||||
shape=[input_depth + h_depth, 4 * self._num_units],
|
shape=[input_depth + h_depth, 4 * self._num_units],
|
||||||
initializer=self._initializer,
|
initializer=self._initializer,
|
||||||
partitioner=maybe_partitioner)
|
partitioner=maybe_partitioner)
|
||||||
|
if self.dtype is None:
|
||||||
|
initializer = init_ops.zeros_initializer
|
||||||
|
else:
|
||||||
|
initializer = init_ops.zeros_initializer(dtype=self.dtype)
|
||||||
self._bias = self.add_variable(
|
self._bias = self.add_variable(
|
||||||
_BIAS_VARIABLE_NAME,
|
_BIAS_VARIABLE_NAME,
|
||||||
shape=[4 * self._num_units],
|
shape=[4 * self._num_units],
|
||||||
initializer=init_ops.zeros_initializer(dtype=self.dtype))
|
initializer=initializer)
|
||||||
if self._use_peepholes:
|
if self._use_peepholes:
|
||||||
self._w_f_diag = self.add_variable("w_f_diag", shape=[self._num_units],
|
self._w_f_diag = self.add_variable("w_f_diag", shape=[self._num_units],
|
||||||
initializer=self._initializer)
|
initializer=self._initializer)
|
||||||
|
Loading…
Reference in New Issue
Block a user