Fix small dtype inference error in ctc loss op test.
Change: 115706184
This commit is contained in:
parent
8d7f58d44d
commit
6390a06d80
@ -181,7 +181,7 @@ class CTCLossTest(tf.test.TestCase):
|
||||
for t in range(5)] + 2 * [np.nan*np.ones((2, depth), np.float32)]
|
||||
|
||||
# convert inputs into [max_time x batch_size x depth tensor] Tensor
|
||||
inputs = np.asarray(inputs)
|
||||
inputs = np.asarray(inputs, dtype=np.float32)
|
||||
|
||||
# len batch_size array of label vectors
|
||||
labels = SimpleSparseTensorFrom([targets_0, targets_1])
|
||||
@ -198,7 +198,7 @@ class CTCLossTest(tf.test.TestCase):
|
||||
for t in range(5)] + 2 * [np.zeros((2, depth), np.float32)]
|
||||
|
||||
# convert grad_truth into [max_time x batch_size x depth] Tensor
|
||||
grad_truth = np.asarray(grad_truth)
|
||||
grad_truth = np.asarray(grad_truth, dtype=np.float32)
|
||||
|
||||
self._testCTCLoss(inputs, seq_lens, labels, loss_truth, grad_truth)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user