Apply a sequence mask for the gradient in ctc_loss_dense.
PiperOrigin-RevId: 295767405
Change-Id: I80a53508288cdc505f876901fde5fa46a7645bca
commit b04371bc95 (parent 884a14ac9a)
@@ -367,7 +367,8 @@ class CTCLossTestV2(test.TestCase):
     batch_size = 8
     num_labels = 6
     label_length = 5
-    num_frames = 12
+    minimum_logits_length = 10
+    num_frames = minimum_logits_length + batch_size
     logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
     labels = random_ops.random_uniform(
         [batch_size, label_length], minval=1, maxval=num_labels,
@@ -379,7 +380,7 @@ class CTCLossTestV2(test.TestCase):
         label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
     labels *= label_mask

-    logit_lengths = [num_frames] * batch_size
+    logit_lengths = math_ops.range(batch_size) + minimum_logits_length

     ctc_loss = ctc_ops.ctc_loss_dense(
         labels=labels,
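The sketch below is not part of the commit; it only illustrates, under the same batch_size = 8 and minimum_logits_length = 10 values used in the test above, what the new per-example logit_lengths look like and why num_frames = minimum_logits_length + batch_size is long enough for every example. It uses the public tf.* API rather than the internal random_ops/math_ops modules.

# Illustrative only; not code from the commit. Assumes TF 2.x eager execution.
import tensorflow as tf

batch_size = 8
minimum_logits_length = 10

# Each example now has a different number of valid frames: 10, 11, ..., 17.
logit_lengths = tf.range(batch_size) + minimum_logits_length
print(logit_lengths.numpy())  # [10 11 12 13 14 15 16 17]

# num_frames is sized so even the longest example (17 frames) still fits.
num_frames = minimum_logits_length + batch_size  # 18

# The mask covers a different prefix of the frame axis for each row.
mask = tf.sequence_mask(logit_lengths, num_frames)  # shape [batch, frames]
print(mask.numpy().sum(axis=1))                     # [10 11 12 13 14 15 16 17]

With uniform lengths ([num_frames] * batch_size) the mask is all ones and the new masking code in ctc_loss_and_grad would never be exercised; varying the lengths per example is what makes the test meaningful.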
@@ -410,8 +411,8 @@ class CTCLossTestV2(test.TestCase):
     self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
     self.assertAllClose(
         *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
-        rtol=2e-06,
-        atol=2e-06)
+        rtol=4e-06,
+        atol=4e-06)

   @test_util.run_v1_only("b/120545219")
   def testCtcLossDenseUniqueFastPathIsSameAsCtcLoss(self):
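The only change in this hunk is loosening the gradient comparison from 2e-06 to 4e-06. As a rough, illustrative sketch (not from the commit), a numpy-style closeness check of this form accepts a difference d when d <= atol + rtol * |expected|, so doubling both tolerances roughly doubles the allowed per-entry gradient discrepancy:

# Illustrative only; a numpy-style closeness check, not TF's internal code.
def close(a, b, rtol, atol):
    return abs(a - b) <= atol + rtol * abs(b)

expected = 1.0
actual = expected + 5e-6          # a 5e-6 discrepancy in one gradient entry

print(close(actual, expected, rtol=2e-6, atol=2e-6))  # False: 5e-6 > 4e-6
print(close(actual, expected, rtol=4e-6, atol=4e-6))  # True:  5e-6 <= 8e-6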
@@ -658,6 +658,17 @@ def ctc_loss_and_grad(logits, labels, label_length, logit_length, unique=None):
   olabel_log_probs = _state_to_olabel(labels, num_labels, fwd_bwd_log_probs)

   grad = math_ops.exp(ilabel_log_probs) - math_ops.exp(olabel_log_probs)

+  # Apply the sequence mask to the gradient. It would be enough to apply the
+  # mask only to ilabel_log_probs, because olabel_log_probs already accounts
+  # for the mask, but applying it to the gradient directly is safer and cleaner.
+  max_logit_length = _get_dim(logits, 0)
+  logit_mask = array_ops.sequence_mask(logit_length, max_logit_length,
+                                       dtypes.float32)
+  logit_mask = array_ops.transpose(logit_mask, perm=[1, 0])
+  logit_mask = array_ops.expand_dims(logit_mask, axis=2)
+  grad *= logit_mask
+
   loss = -log_likelihood
   return loss, grad
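Not part of the commit: a small sketch of the shape handling the new masking code performs, written against the public tf.* API instead of the internal array_ops/dtypes modules. sequence_mask produces a [batch, frames] mask, while the CTC gradient is laid out [frames, batch, labels], hence the transpose and the trailing singleton label axis for broadcasting.

# Illustrative only; shapes and lengths are made up to match the test above.
import tensorflow as tf

num_frames, batch_size, num_labels = 18, 8, 6
grad = tf.random.uniform([num_frames, batch_size, num_labels])
logit_length = tf.range(batch_size) + 10        # per-example valid frame counts

mask = tf.sequence_mask(logit_length, num_frames, tf.float32)  # [batch, frames]
mask = tf.transpose(mask, perm=[1, 0])                         # [frames, batch]
mask = tf.expand_dims(mask, axis=2)                            # [frames, batch, 1]
masked_grad = grad * mask            # broadcasts over the label axis

# Frames at or beyond an example's length now contribute zero gradient:
# example 0 has length 10, so frame 17 is fully masked out.
print(tf.reduce_sum(tf.abs(masked_grad[17, 0])).numpy())  # 0.0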