Apply a sequence mask to the gradient in ctc_loss_dense.

PiperOrigin-RevId: 295767405
Change-Id: I80a53508288cdc505f876901fde5fa46a7645bca
Authored by A. Unique TensorFlower on 2020-02-18 10:37:24 -08:00; committed by TensorFlower Gardener
parent 884a14ac9a
commit b04371bc95
2 changed files with 16 additions and 4 deletions
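
For context, the fix builds a per-example sequence mask over the time axis and broadcasts it against the time-major gradient. Below is a minimal sketch of that shape manipulation using the public tf.* API rather than the internal array_ops/math_ops modules; the toy shapes and tensor names are illustrative assumptions, not the patched code itself.

import tensorflow as tf

# Toy shapes for illustration; in ctc_loss_and_grad the gradient is time-major,
# shaped [max_time, batch, num_labels], and logit_length holds per-example lengths.
max_time, batch_size, num_labels = 6, 2, 4
grad = tf.random.uniform([max_time, batch_size, num_labels])
logit_length = tf.constant([6, 3])

# [batch, max_time] mask: 1.0 for valid frames, 0.0 beyond each example's length.
mask = tf.sequence_mask(logit_length, maxlen=max_time, dtype=tf.float32)
# Transpose to [max_time, batch] and add a trailing label axis so the mask
# broadcasts against the [max_time, batch, num_labels] gradient.
mask = tf.expand_dims(tf.transpose(mask, perm=[1, 0]), axis=2)
masked_grad = grad * mask  # frames at or past logit_length now get zero gradient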


@@ -367,7 +367,8 @@ class CTCLossTestV2(test.TestCase):
 batch_size = 8
 num_labels = 6
 label_length = 5
-num_frames = 12
+minimum_logits_length = 10
+num_frames = minimum_logits_length + batch_size
 logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
 labels = random_ops.random_uniform(
     [batch_size, label_length], minval=1, maxval=num_labels,
@@ -379,7 +380,7 @@ class CTCLossTestV2(test.TestCase):
     label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
 labels *= label_mask
-logit_lengths = [num_frames] * batch_size
+logit_lengths = math_ops.range(batch_size) + minimum_logits_length

 ctc_loss = ctc_ops.ctc_loss_dense(
     labels=labels,
@@ -410,8 +411,8 @@ class CTCLossTestV2(test.TestCase):
 self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
 self.assertAllClose(
     *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
-    rtol=2e-06,
-    atol=2e-06)
+    rtol=4e-06,
+    atol=4e-06)

 @test_util.run_v1_only("b/120545219")
 def testCtcLossDenseUniqueFastPathIsSameAsCtcLoss(self):


@@ -658,6 +658,17 @@ def ctc_loss_and_grad(logits, labels, label_length, logit_length, unique=None):
 olabel_log_probs = _state_to_olabel(labels, num_labels, fwd_bwd_log_probs)
 grad = math_ops.exp(ilabel_log_probs) - math_ops.exp(olabel_log_probs)
+
+# Applies the sequence mask to the gradient. It would be enough to apply the
+# mask only to ilabel_log_probs, because olabel_log_probs already accounts for
+# the mask; applying it to the full gradient is simply safer and cleaner.
+max_logit_length = _get_dim(logits, 0)
+logit_mask = array_ops.sequence_mask(logit_length, max_logit_length,
+                                     dtypes.float32)
+logit_mask = array_ops.transpose(logit_mask, perm=[1, 0])
+logit_mask = array_ops.expand_dims(logit_mask, axis=2)
+grad *= logit_mask
+
 loss = -log_likelihood
 return loss, grad
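
As a rough check of what the updated test targets, the sketch below gives each batch element a different logit length and inspects the gradient of the public tf.nn.ctc_loss wrapper (assumed here to take the ctc_loss_dense path for dense labels). Names, shapes, and the blank_index=0 choice mirror the test above but are assumptions; with the mask in place, frames at or beyond each example's logit_length should receive zero gradient.

import tensorflow as tf

batch_size, num_labels, label_length = 8, 6, 5
minimum_logits_length = 10
num_frames = minimum_logits_length + batch_size

logits = tf.random.uniform([num_frames, batch_size, num_labels])
# Labels drawn from [1, num_labels) so index 0 can serve as the blank.
labels = tf.random.uniform(
    [batch_size, label_length], minval=1, maxval=num_labels, dtype=tf.int64)
label_lengths = tf.fill([batch_size], label_length)
logit_lengths = tf.range(batch_size) + minimum_logits_length  # per-example lengths

with tf.GradientTape() as tape:
  tape.watch(logits)
  loss = tf.nn.ctc_loss(
      labels=labels, logits=logits,
      label_length=label_lengths, logit_length=logit_lengths,
      logits_time_major=True, blank_index=0)
grads = tape.gradient(loss, logits)  # [num_frames, batch_size, num_labels]

for b in range(batch_size):
  tail = grads[int(logit_lengths[b]):, b, :]
  # Expected to be all zeros once the sequence mask is applied to the gradient.
  print(b, bool(tf.reduce_all(tail == 0.0)))

Presumably this is also why the test switches from a constant num_frames to per-example logit lengths: with uniform lengths the mask is all ones, so any missing masking would go unnoticed.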