From b04371bc952f9f9668e862d82db71651fdef8dc6 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 18 Feb 2020 10:37:24 -0800
Subject: [PATCH] Apply a sequence mask for the gradient in ctc_loss_dense.

PiperOrigin-RevId: 295767405
Change-Id: I80a53508288cdc505f876901fde5fa46a7645bca
---
 tensorflow/python/kernel_tests/ctc_loss_op_test.py |  9 +++++----
 tensorflow/python/ops/ctc_ops.py                   | 11 +++++++++++
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/tensorflow/python/kernel_tests/ctc_loss_op_test.py b/tensorflow/python/kernel_tests/ctc_loss_op_test.py
index 036cd8ed648..e7f1f8a5e85 100644
--- a/tensorflow/python/kernel_tests/ctc_loss_op_test.py
+++ b/tensorflow/python/kernel_tests/ctc_loss_op_test.py
@@ -367,7 +367,8 @@ class CTCLossTestV2(test.TestCase):
     batch_size = 8
     num_labels = 6
     label_length = 5
-    num_frames = 12
+    minimum_logits_length = 10
+    num_frames = minimum_logits_length + batch_size
     logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
     labels = random_ops.random_uniform(
         [batch_size, label_length], minval=1, maxval=num_labels,
@@ -379,7 +380,7 @@ class CTCLossTestV2(test.TestCase):
         label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
     labels *= label_mask
 
-    logit_lengths = [num_frames] * batch_size
+    logit_lengths = math_ops.range(batch_size) + minimum_logits_length
 
     ctc_loss = ctc_ops.ctc_loss_dense(
         labels=labels,
@@ -410,8 +411,8 @@ class CTCLossTestV2(test.TestCase):
     self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
     self.assertAllClose(
         *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
-        rtol=2e-06,
-        atol=2e-06)
+        rtol=4e-06,
+        atol=4e-06)
 
   @test_util.run_v1_only("b/120545219")
   def testCtcLossDenseUniqueFastPathIsSameAsCtcLoss(self):
diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py
index d0298fd8b6d..4b3a5dd7fe9 100644
--- a/tensorflow/python/ops/ctc_ops.py
+++ b/tensorflow/python/ops/ctc_ops.py
@@ -658,6 +658,17 @@ def ctc_loss_and_grad(logits, labels, label_length, logit_length, unique=None):
 
   olabel_log_probs = _state_to_olabel(labels, num_labels, fwd_bwd_log_probs)
   grad = math_ops.exp(ilabel_log_probs) - math_ops.exp(olabel_log_probs)
+
+  # Applies the sequence mask to the gradient. It would be enough to apply the
+  # mask only to ilabel_log_probs, because olabel_log_probs already accounts
+  # for the mask, but it is safer and cleaner to apply it to the gradient.
+  max_logit_length = _get_dim(logits, 0)
+  logit_mask = array_ops.sequence_mask(logit_length, max_logit_length,
+                                       dtypes.float32)
+  logit_mask = array_ops.transpose(logit_mask, perm=[1, 0])
+  logit_mask = array_ops.expand_dims(logit_mask, axis=2)
+  grad *= logit_mask
+
   loss = -log_likelihood
 
   return loss, grad
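
Note (not part of the patch): the sketch below is a minimal, standalone illustration of what the added masking does, written against the public tf.sequence_mask / tf.transpose / tf.expand_dims API rather than TensorFlow's internal array_ops module; the grad tensor and logit_length values are made up for demonstration. The mask is built as [batch, max_time], transposed to [max_time, batch], and given a trailing axis so it broadcasts against a gradient of shape [max_time, batch, num_labels], zeroing every frame past each example's logit_length.

import tensorflow as tf

max_time, batch, num_labels = 4, 2, 3
grad = tf.ones([max_time, batch, num_labels])       # stand-in for the CTC gradient
logit_length = tf.constant([2, 4])                  # valid frame count per example

# [batch, max_time] float mask: [[1, 1, 0, 0], [1, 1, 1, 1]]
logit_mask = tf.sequence_mask(logit_length, max_time, tf.float32)
# Rearrange to [max_time, batch, 1] so it broadcasts over the label axis.
logit_mask = tf.expand_dims(tf.transpose(logit_mask, perm=[1, 0]), axis=2)

masked_grad = grad * logit_mask
print(masked_grad.numpy()[:, 0, 0])                 # -> [1. 1. 0. 0.]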