diff --git a/tensorflow/python/kernel_tests/ctc_loss_op_test.py b/tensorflow/python/kernel_tests/ctc_loss_op_test.py
index e7f1f8a5e85..19918496fbd 100644
--- a/tensorflow/python/kernel_tests/ctc_loss_op_test.py
+++ b/tensorflow/python/kernel_tests/ctc_loss_op_test.py
@@ -460,6 +460,69 @@ class CTCLossTestV2(test.TestCase):
         time_major=True)
     tf_nn_ctc_grads = gradients_impl.gradients(tf_nn_ctc_loss, [logits])[0]
 
+    with self.cached_session():
+      for _ in range(32):
+        self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
+        self.assertAllClose(
+            *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
+            rtol=2e-06,
+            atol=2e-06)
+
+  @test_util.run_v1_only("b/120545219")
+  def testCtcLossDenseUniqueFastPathWithBlankIndexIsSameAsCtcLoss(self):
+    random_seed.set_random_seed(5)
+
+    batch_size = 8
+    num_labels = 6
+    label_length = 5
+    num_frames = 12
+    logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
+    labels = random_ops.random_uniform([batch_size, label_length],
+                                       minval=0,
+                                       maxval=num_labels - 1,
+                                       dtype=dtypes.int64)
+
+    label_lengths = random_ops.random_uniform([batch_size],
+                                              minval=2,
+                                              maxval=label_length,
+                                              dtype=dtypes.int64)
+    label_mask = array_ops.sequence_mask(
+        label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
+    labels *= label_mask
+
+    logit_lengths = [num_frames] * batch_size
+
+    tf_ctc_loss_labels = math_ops.cast(labels, dtypes.int32)
+    tf_ctc_loss_labels = ctc_ops.dense_labels_to_sparse(tf_ctc_loss_labels,
+                                                        label_lengths)
+
+    tf_nn_ctc_loss = ctc_ops.ctc_loss(
+        labels=tf_ctc_loss_labels,
+        inputs=logits,
+        sequence_length=logit_lengths,
+        time_major=True)
+    tf_nn_ctc_grads = gradients_impl.gradients(tf_nn_ctc_loss, [logits])[0]
+
+    # Shift the blank logits/labels to be somewhere in the middle.
+    blank_index = 2
+    shifted_logits = array_ops.concat([
+        logits[:, :, :blank_index],
+        logits[:, :, -1:],
+        logits[:, :, blank_index:-1],
+    ],
+                                      axis=2)
+    shifted_labels = array_ops.where_v2(labels < blank_index, labels,
+                                        labels + 1)
+
+    ctc_loss = ctc_ops.ctc_loss_dense(
+        labels=shifted_labels,
+        logits=shifted_logits,
+        label_length=label_lengths,
+        logit_length=logit_lengths,
+        blank_index=blank_index,
+        unique=ctc_ops.ctc_unique_labels(shifted_labels))
+    ctc_loss_grads = gradients_impl.gradients(ctc_loss, [logits])[0]
+
     with self.cached_session() as sess:
       for _ in range(32):
         self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
@@ -773,6 +836,41 @@ class CTCLossTestV2(test.TestCase):
             [22.0 + 23.0 + 24.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
     ])
 
+  def testStateToOlabelUniqueSinglePath(self):
+    labels = [
+        [3, 4, 3],
+        [1, 0, 0],
+    ]
+    num_labels = 8
+
+    # 3 frames, batch of 2, 8 states (4 label, 4 blank).
+    #
+    # There is only a single valid path for each sequence because the frame
+    # lengths and the label lengths are the same.
+    states = [[[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+               [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
+              [[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+               [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
+              [[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
+               [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]
+    labels = ops.convert_to_tensor(labels)
+    states = math_ops.log(states)
+    olabel = ctc_ops._state_to_olabel_unique(labels, num_labels, states,
+                                             ctc_ops.ctc_unique_labels(labels))
+    olabel = math_ops.exp(olabel)
+    blank = olabel[:, :, 0]
+
+    self.assertAllClose(blank, [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
+    self.assertAllClose(olabel[:, :, 1:],
+                        [
+                            [[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
+                             [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
+                            [[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
+                             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
+                            [[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
+                             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
+                        ])
+
   @test_util.run_deprecated_v1
   def testScan(self):
     with ops.device("/GPU:0" if test.is_gpu_available() else "/CPU:0"):
diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py
index 4b3a5dd7fe9..d18799c5224 100644
--- a/tensorflow/python/ops/ctc_ops.py
+++ b/tensorflow/python/ops/ctc_ops.py
@@ -601,9 +601,18 @@ def _state_to_olabel_unique(labels, num_labels, states, unique):
       updates=batch_state_major,
       shape=[batch_size * num_labels, num_frames])
   scatter = array_ops.reshape(scatter, [batch_size, num_labels, num_frames])
+
+  mask = array_ops.ones_like(batch_state_major, dtype=dtypes.bool)
+  mask = array_ops.scatter_nd(
+      indices=indices,
+      updates=mask,
+      shape=[batch_size * num_labels, num_frames])
+  mask = array_ops.reshape(mask, [batch_size, num_labels, num_frames])
+
   scatter = array_ops.where(
-      math_ops.equal(scatter, 0.0),
-      array_ops.fill(array_ops.shape(scatter), math_ops.log(0.0)), scatter)
+      mask, scatter,
+      array_ops.fill(array_ops.shape(scatter), math_ops.log(0.0)))
+
   label_olabels = array_ops.transpose(scatter, [2, 0, 1])
   label_olabels = label_olabels[:, :, 1:]
 
@@ -1010,6 +1019,14 @@ def ctc_loss_dense(labels,
 
   if unique:
     unique_y, unique_idx = unique
+    if blank_index != 0:
+      unique_y = array_ops.where(unique_y < blank_index, unique_y + 1,
+                                 unique_y)
+      label_mask_len = math_ops.reduce_max(unique_idx, axis=1) + 1
+      max_label_length = _get_dim(unique_y, 1)
+      label_mask = array_ops.sequence_mask(label_mask_len, max_label_length)
+      unique_y = array_ops.where(label_mask, unique_y,
+                                 array_ops.zeros_like(unique_y))
     args.extend([unique_y, unique_idx])
 
   @custom_gradient.custom_gradient
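
Note on the `_state_to_olabel_unique` change: `scatter_nd` leaves unwritten entries at 0.0, and the old code treated any 0.0 in `scatter` as "never written", replacing it with log(0) = -inf. But 0.0 is also a legitimate log-probability, namely probability 1.0, which is exactly the case `testStateToOlabelUniqueSinglePath` constructs, so real values were being clobbered. The patch scatters a boolean mask alongside the values and uses it to decide which entries were actually written. The following is a minimal NumPy sketch of that idea, not the TensorFlow code itself; `scatter_with_mask` is a hypothetical helper name.

import numpy as np

def scatter_with_mask(indices, updates, size):
  """Scatter `updates` into a length-`size` vector, tracking which entries
  were actually written instead of relying on a 0.0 sentinel."""
  out = np.zeros(size)
  written = np.zeros(size, dtype=bool)
  out[indices] = updates
  written[indices] = True
  # Unwritten entries become log(0) = -inf; written entries keep their
  # value, even when that value is exactly 0.0 (probability 1.0).
  return np.where(written, out, -np.inf)

log_prob_of_one = np.log(1.0)  # == 0.0, the problematic case
print(scatter_with_mask(indices=[1], updates=[log_prob_of_one], size=3))
# [-inf  0. -inf]: index 1 survives, whereas the old `scatter == 0.0`
# test would have replaced it with -inf too.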
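
Note on the `ctc_loss_dense` change: when `blank_index != 0`, the function internally relocates the blank class to index 0, shifting every class below `blank_index` up by one. A caller-supplied `unique` tuple from `ctc_unique_labels` is expressed in the original label space, so the added block applies the same shift to `unique_y` and then forces the padded tail back to 0, since otherwise the `+ 1` would turn padding zeros into class 1 instead of blank. A hedged NumPy sketch of that remapping follows; `remap_unique_for_blank_index` is a hypothetical helper name, not TensorFlow API.

import numpy as np

def remap_unique_for_blank_index(unique_y, unique_idx, blank_index):
  # Shift classes below blank_index up by one, mirroring the internal
  # label remap that moves the blank class to index 0.
  unique_y = np.where(unique_y < blank_index, unique_y + 1, unique_y)
  # Positions past the last real unique label are padding; zero them so
  # they keep mapping to the relocated blank class.
  label_mask_len = unique_idx.max(axis=1) + 1
  max_label_length = unique_y.shape[1]
  label_mask = np.arange(max_label_length)[None, :] < label_mask_len[:, None]
  return np.where(label_mask, unique_y, 0)

# One batch row whose unique labels are [1, 3], padded to length 4.
unique_y = np.array([[1, 3, 0, 0]])
unique_idx = np.array([[0, 1, 0, 0]])
print(remap_unique_for_blank_index(unique_y, unique_idx, blank_index=2))
# [[2 3 0 0]]: class 1 shifts up past the relocated blank; class 3 is
# already above blank_index and stays; padding returns to 0 (blank).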