Fix a bug in ctc_loss_dense with unique.

1. The computation of unique labels has to account for the case where the blank label is not 0.

2. The scattering mechanism assumes that a value of 0.0 always marks a padding region, but this does not hold when there is only a single valid path: 0.0 = log(1.0) is then a legitimate log-probability. This happens when the logit length and the label length are the same.
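For illustration, a minimal standalone sketch (plain Python, not part of the diff) of why a 0.0 sentinel is ambiguous:

import math

# A certain path has log-probability log(1.0) = 0.0, exactly the value
# scatter_nd leaves in slots it never writes.
print(math.log(1.0) == 0.0)  # True: "certain path" and "padding" collide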

PiperOrigin-RevId: 296090785
Change-Id: I803508252e688571bca531b1aa95dd2160902d4c
A. Unique TensorFlower authored 2020-02-19 17:01:57 -08:00; committed by TensorFlower Gardener
parent 513c39fc74
commit abaab5b360
2 changed files with 117 additions and 2 deletions

tensorflow/python/kernel_tests/ctc_loss_op_test.py

@@ -460,6 +460,69 @@ class CTCLossTestV2(test.TestCase):
        time_major=True)
    tf_nn_ctc_grads = gradients_impl.gradients(tf_nn_ctc_loss, [logits])[0]

    with self.cached_session():
      for _ in range(32):
        self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
        self.assertAllClose(
            *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
            rtol=2e-06,
            atol=2e-06)
  @test_util.run_v1_only("b/120545219")
  def testCtcLossDenseUniqueFastPathWithBlankIndexIsSameAsCtcLoss(self):
    random_seed.set_random_seed(5)

    batch_size = 8
    num_labels = 6
    label_length = 5
    num_frames = 12
    logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
    labels = random_ops.random_uniform([batch_size, label_length],
                                       minval=0,
                                       maxval=num_labels - 1,
                                       dtype=dtypes.int64)

    label_lengths = random_ops.random_uniform([batch_size],
                                              minval=2,
                                              maxval=label_length,
                                              dtype=dtypes.int64)
    label_mask = array_ops.sequence_mask(
        label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
    labels *= label_mask

    logit_lengths = [num_frames] * batch_size

    tf_ctc_loss_labels = math_ops.cast(labels, dtypes.int32)
    tf_ctc_loss_labels = ctc_ops.dense_labels_to_sparse(tf_ctc_loss_labels,
                                                        label_lengths)

    tf_nn_ctc_loss = ctc_ops.ctc_loss(
        labels=tf_ctc_loss_labels,
        inputs=logits,
        sequence_length=logit_lengths,
        time_major=True)
    tf_nn_ctc_grads = gradients_impl.gradients(tf_nn_ctc_loss, [logits])[0]

    # Shift the blank logits/labels to be somewhere in the middle.
    blank_index = 2
    shifted_logits = array_ops.concat([
        logits[:, :, :blank_index],
        logits[:, :, -1:],
        logits[:, :, blank_index:-1],
    ], axis=2)
    shifted_labels = array_ops.where_v2(labels < blank_index, labels,
                                        labels + 1)

    ctc_loss = ctc_ops.ctc_loss_dense(
        labels=shifted_labels,
        logits=shifted_logits,
        label_length=label_lengths,
        logit_length=logit_lengths,
        blank_index=blank_index,
        unique=ctc_ops.ctc_unique_labels(shifted_labels))
    ctc_loss_grads = gradients_impl.gradients(ctc_loss, [logits])[0]

    with self.cached_session() as sess:
      for _ in range(32):
        self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
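For intuition, the shift in the test above is a pure permutation of the class axis plus a matching relabeling. A toy NumPy sketch (shapes and values are hypothetical, not taken from the test):

import numpy as np

logits = np.arange(2 * 1 * 4, dtype=np.float32).reshape(2, 1, 4)  # blank last
blank_index = 2

# Move the last class (the original blank) to position blank_index.
shifted_logits = np.concatenate([
    logits[:, :, :blank_index],    # classes 0..blank_index-1 stay put
    logits[:, :, -1:],             # old blank slides into blank_index
    logits[:, :, blank_index:-1],  # remaining classes shift up by one
], axis=2)

# Labels at or above blank_index must shift up by one to follow suit.
labels = np.array([[0, 2, 3]])
shifted_labels = np.where(labels < blank_index, labels, labels + 1)  # [[0, 3, 4]]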
@@ -773,6 +836,41 @@ class CTCLossTestV2(test.TestCase):
                             [22.0 + 23.0 + 24.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
                        ])

  def testStateToOlabelUniqueSinglePath(self):
    labels = [
        [3, 4, 3],
        [1, 0, 0],
    ]
    num_labels = 8

    # 3 frames, 2 batch, 8 states (4 label, 4 blank).
    #
    # There is only a single valid path for each sequence because the frame
    # lengths and the label lengths are the same.
    states = [[[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
               [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
              [[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
               [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
              [[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
               [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]

    labels = ops.convert_to_tensor(labels)
    states = math_ops.log(states)
    olabel = ctc_ops._state_to_olabel_unique(labels, num_labels, states,
                                             ctc_ops.ctc_unique_labels(labels))
    olabel = math_ops.exp(olabel)
    blank = olabel[:, :, 0]

    self.assertAllClose(blank, [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
    self.assertAllClose(olabel[:, :, 1:],
                        [
                            [[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
                             [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
                            [[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
                             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
                            [[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
                             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
                        ])
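Reading off the expected values: for the first batch element the single valid path visits one label state per frame (label positions 0, 1, 2, i.e. labels 3, 4, 3); in olabel[:, :, 1:], with the blank column removed, label k lands in column k - 1, hence the 1.0 entries at columns 2, 3, 2. The second batch element emits label 1 (column 0) in its first frame only. Without the mask fix, the legitimate log(1.0) = 0.0 entries in states would be mistaken for padding and overwritten with log(0.0).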

  @test_util.run_deprecated_v1
  def testScan(self):
    with ops.device("/GPU:0" if test.is_gpu_available() else "/CPU:0"):

tensorflow/python/ops/ctc_ops.py

@@ -601,9 +601,18 @@ def _state_to_olabel_unique(labels, num_labels, states, unique):
      updates=batch_state_major,
      shape=[batch_size * num_labels, num_frames])
  scatter = array_ops.reshape(scatter, [batch_size, num_labels, num_frames])

  mask = array_ops.ones_like(batch_state_major, dtype=dtypes.bool)
  mask = array_ops.scatter_nd(
      indices=indices,
      updates=mask,
      shape=[batch_size * num_labels, num_frames])
  mask = array_ops.reshape(mask, [batch_size, num_labels, num_frames])

  scatter = array_ops.where(
      mask, scatter,
      array_ops.fill(array_ops.shape(scatter), math_ops.log(0.0)))

  label_olabels = array_ops.transpose(scatter, [2, 0, 1])
  label_olabels = label_olabels[:, :, 1:]
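The fix in miniature, as a standalone sketch with toy shapes (not the real ctc_ops tensors): scatter a boolean mask through the same indices as the values, and let the mask, rather than a 0.0 sentinel, decide where the padding fill log(0.0) goes.

import tensorflow as tf

indices = tf.constant([[0], [2]])
updates = tf.constant([0.0, -1.5])  # this 0.0 is a real value: log(1.0)
scatter = tf.scatter_nd(indices, updates, shape=[4])

# True at exactly the positions that were written.
mask = tf.scatter_nd(indices, tf.ones_like(updates, dtype=tf.bool), shape=[4])

# Unwritten slots become log(0.0) = -inf; the written 0.0 survives.
padded = tf.where(mask, scatter, tf.fill([4], tf.math.log(0.0)))
# padded: [0.0, -inf, -1.5, -inf]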
@@ -1010,6 +1019,14 @@ def ctc_loss_dense(labels,
    if unique:
      unique_y, unique_idx = unique
      if blank_index != 0:
        unique_y = array_ops.where(unique_y < blank_index, unique_y + 1,
                                   unique_y)
        label_mask_len = math_ops.reduce_max(unique_idx, axis=1) + 1
        max_label_length = _get_dim(unique_y, 1)
        label_mask = array_ops.sequence_mask(label_mask_len, max_label_length)
        unique_y = array_ops.where(label_mask, unique_y,
                                   array_ops.zeros_like(unique_y))
      args.extend([unique_y, unique_idx])

  @custom_gradient.custom_gradient
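A toy NumPy walkthrough of the unique_y remap above (values are hypothetical, with blank_index = 2):

import numpy as np

unique_y = np.array([[0, 1, 3, 0, 0]])    # unique labels, right-padded with 0
unique_idx = np.array([[0, 1, 2, 2, 2]])  # maps label positions into unique_y
blank_index = 2

# Internally the blank moves to class 0, so labels below blank_index shift
# up by one; labels above it already have the right index.
shifted = np.where(unique_y < blank_index, unique_y + 1, unique_y)  # [[1, 2, 3, 1, 1]]

# The shift also bumped the right-padding from 0 to 1. Recover the number of
# real entries from unique_idx and zero the padding out again.
num_real = unique_idx.max(axis=1) + 1                           # [3]
mask = np.arange(unique_y.shape[1])[None, :] < num_real[:, None]
shifted = np.where(mask, shifted, 0)                            # [[1, 2, 3, 0, 0]]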