Fix a bug in ctc_loss_dense with unique.

1. Unique labels have to account for the case where the blank label is not 0.
2. The scattering mechanism assumes that 0.0 always corresponds to a padding
   region, but this is not the case when there is only a single valid path
   (0.0 = log(1.0)). This happens when the lengths of the logits and the
   labels are the same.

PiperOrigin-RevId: 296090785
Change-Id: I803508252e688571bca531b1aa95dd2160902d4c
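In log space 0.0 is a legitimate value rather than a sentinel: log(1.0) = 0.0, which is exactly what a frame contributes when its only valid path has probability 1. A minimal NumPy sketch of the failure mode behind fix 2 (the arrays and names here are illustrative, not from the patch):

```python
import numpy as np

# Scattered log-probabilities for three slots; slot 2 was never written,
# slot 0 holds a genuine log(1.0) = 0.0 from a single-valid-path frame.
scatter = np.array([0.0, np.log(0.5), 0.0])
written = np.array([True, True, False])

# Old heuristic: every 0.0 is padding -> the real log(1.0) is erased too.
old = np.where(scatter == 0.0, -np.inf, scatter)
print(old)    # [       -inf -0.69314718        -inf]

# Fix: an explicit mask separates "written 0.0" from "never written".
fixed = np.where(written, scatter, -np.inf)
print(fixed)  # [ 0.         -0.69314718        -inf]
```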
```diff
@@ -460,6 +460,69 @@ class CTCLossTestV2(test.TestCase):
         time_major=True)
     tf_nn_ctc_grads = gradients_impl.gradients(tf_nn_ctc_loss, [logits])[0]
 
     with self.cached_session():
       for _ in range(32):
         self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
         self.assertAllClose(
             *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
             rtol=2e-06,
             atol=2e-06)
 
+  @test_util.run_v1_only("b/120545219")
+  def testCtcLossDenseUniqueFastPathWithBlankIndexIsSameAsCtcLoss(self):
+    random_seed.set_random_seed(5)
+
+    batch_size = 8
+    num_labels = 6
+    label_length = 5
+    num_frames = 12
+    logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
+    labels = random_ops.random_uniform([batch_size, label_length],
+                                       minval=0,
+                                       maxval=num_labels - 1,
+                                       dtype=dtypes.int64)
+
+    label_lengths = random_ops.random_uniform([batch_size],
+                                              minval=2,
+                                              maxval=label_length,
+                                              dtype=dtypes.int64)
+    label_mask = array_ops.sequence_mask(
+        label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
+    labels *= label_mask
+
+    logit_lengths = [num_frames] * batch_size
+
+    tf_ctc_loss_labels = math_ops.cast(labels, dtypes.int32)
+    tf_ctc_loss_labels = ctc_ops.dense_labels_to_sparse(tf_ctc_loss_labels,
+                                                        label_lengths)
+
+    tf_nn_ctc_loss = ctc_ops.ctc_loss(
+        labels=tf_ctc_loss_labels,
+        inputs=logits,
+        sequence_length=logit_lengths,
+        time_major=True)
+    tf_nn_ctc_grads = gradients_impl.gradients(tf_nn_ctc_loss, [logits])[0]
+
+    # Shift the blank logits/labels to be somewhere in the middle.
+    blank_index = 2
+    shifted_logits = array_ops.concat([
+        logits[:, :, :blank_index],
+        logits[:, :, -1:],
+        logits[:, :, blank_index:-1],
+    ],
+                                      axis=2)
+    shifted_labels = array_ops.where_v2(labels < blank_index, labels,
+                                        labels + 1)
+
+    ctc_loss = ctc_ops.ctc_loss_dense(
+        labels=shifted_labels,
+        logits=shifted_logits,
+        label_length=label_lengths,
+        logit_length=logit_lengths,
+        blank_index=blank_index,
+        unique=ctc_ops.ctc_unique_labels(shifted_labels))
+    ctc_loss_grads = gradients_impl.gradients(ctc_loss, [logits])[0]
+
+    with self.cached_session() as sess:
+      for _ in range(32):
+        self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
+        self.assertAllClose(
+            *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
+            rtol=2e-06,
+            atol=2e-06)
```
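The test above exercises both fixes at once: `ctc_ops.ctc_loss` keeps the blank in the last class, while `ctc_loss_dense` is told `blank_index=2`, so the logit channels and the label ids have to be permuted consistently before the two losses can be compared. A reduced NumPy sketch of that permutation (shapes shrunk for illustration, not taken from the patch):

```python
import numpy as np

blank_index = 2
logits = np.arange(6).reshape(1, 1, 6)  # [frames, batch, num_labels]
shifted_logits = np.concatenate([
    logits[:, :, :blank_index],    # classes 0..1 stay in place
    logits[:, :, -1:],             # old blank (last channel) moves to index 2
    logits[:, :, blank_index:-1],  # classes 2..4 shift right by one
], axis=2)
print(shifted_logits[0, 0])  # [0 1 5 2 3 4]

# Label ids >= blank_index move up by one so they keep pointing at the
# same logit channel after the shift.
labels = np.array([0, 1, 2, 3, 4])
shifted_labels = np.where(labels < blank_index, labels, labels + 1)
print(shifted_labels)  # [0 1 3 4 5]
```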
```diff
@@ -773,6 +836,41 @@ class CTCLossTestV2(test.TestCase):
              [22.0 + 23.0 + 24.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
         ])
 
+  def testStateToOlabelUniqueSinglePath(self):
+    labels = [
+        [3, 4, 3],
+        [1, 0, 0],
+    ]
+    num_labels = 8
+
+    # 3 frames, 2 batch, 8 states (4 label, 4 blank).
+    #
+    # There is only a single valid path for each sequence because the frame
+    # lengths and the label lengths are the same.
+    states = [[[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+               [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
+              [[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+               [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
+              [[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
+               [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]
+    labels = ops.convert_to_tensor(labels)
+    states = math_ops.log(states)
+    olabel = ctc_ops._state_to_olabel_unique(labels, num_labels, states,
+                                             ctc_ops.ctc_unique_labels(labels))
+    olabel = math_ops.exp(olabel)
+    blank = olabel[:, :, 0]
+
+    self.assertAllClose(blank, [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
+    self.assertAllClose(olabel[:, :, 1:],
+                        [
+                            [[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
+                             [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
+                            [[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
+                             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
+                            [[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
+                             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
+                        ])
+
   @test_util.run_deprecated_v1
   def testScan(self):
     with ops.device("/GPU:0" if test.is_gpu_available() else "/CPU:0"):
```
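The comment in `testStateToOlabelUniqueSinglePath` is the crux: with 3 frames and a 3-symbol label there is no room to insert a blank or repeat a symbol, so exactly one CTC alignment survives, and its probability is 1.0, i.e. 0.0 in log space. A brute-force check of that claim (`collapse` is a from-scratch helper written for this sketch, not a TF API):

```python
import itertools

def collapse(path, blank=0):
  """CTC collapse: merge consecutive repeats, then drop blanks."""
  out = []
  for s in path:
    if out and s == out[-1]:
      continue
    out.append(s)
  return [s for s in out if s != blank]

# Enumerate every length-3 path over 8 classes and keep the ones that
# collapse to the 3-symbol target [3, 4, 3].
target = [3, 4, 3]
paths = [p for p in itertools.product(range(8), repeat=3)
         if collapse(list(p)) == target]
print(paths)  # [(3, 4, 3)] -- the single valid path
```

That forced alignment is why each frame's state vector in the test is a one-hot, and why the old `scatter == 0.0` padding test misread these genuine log(1.0) entries.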
```diff
@@ -601,9 +601,18 @@ def _state_to_olabel_unique(labels, num_labels, states, unique):
       updates=batch_state_major,
       shape=[batch_size * num_labels, num_frames])
   scatter = array_ops.reshape(scatter, [batch_size, num_labels, num_frames])
 
+  mask = array_ops.ones_like(batch_state_major, dtype=dtypes.bool)
+  mask = array_ops.scatter_nd(
+      indices=indices,
+      updates=mask,
+      shape=[batch_size * num_labels, num_frames])
+  mask = array_ops.reshape(mask, [batch_size, num_labels, num_frames])
+
   scatter = array_ops.where(
-      math_ops.equal(scatter, 0.0),
-      array_ops.fill(array_ops.shape(scatter), math_ops.log(0.0)), scatter)
+      mask, scatter,
+      array_ops.fill(array_ops.shape(scatter), math_ops.log(0.0)))
 
   label_olabels = array_ops.transpose(scatter, [2, 0, 1])
   label_olabels = label_olabels[:, :, 1:]
```
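The hunk above replaces the value test `scatter == 0.0` with a boolean mask scattered through the same indices, so padding is identified by "was this position written" rather than by the written value. A hedged TF2 sketch of the same pattern on simplified 1-D shapes (the real code scatters into [batch_size * num_labels, num_frames]):

```python
import tensorflow as tf

# Scatter log-probabilities into a dense tensor; one of the written values
# is a genuine log(1.0) == 0.0.
indices = tf.constant([[0], [2]])
updates = tf.math.log(tf.constant([1.0, 0.5]))
shape = tf.constant([4])
scatter = tf.scatter_nd(indices, updates, shape)

# Scatter an all-True mask through the SAME indices: True marks positions
# that actually received a value.
mask = tf.scatter_nd(indices, tf.ones_like(updates, dtype=tf.bool), shape)

# Only the untouched (padding) positions get log(0) = -inf; the written
# 0.0 survives.
olabel = tf.where(mask, scatter, tf.fill([4], float("-inf")))
print(olabel.numpy())  # [ 0.        -inf -0.6931472 -inf]
```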
```diff
@@ -1010,6 +1019,14 @@ def ctc_loss_dense(labels,
 
     if unique:
       unique_y, unique_idx = unique
+      if blank_index != 0:
+        unique_y = array_ops.where(unique_y < blank_index, unique_y + 1,
+                                   unique_y)
+        label_mask_len = math_ops.reduce_max(unique_idx, axis=1) + 1
+        max_label_length = _get_dim(unique_y, 1)
+        label_mask = array_ops.sequence_mask(label_mask_len, max_label_length)
+        unique_y = array_ops.where(label_mask, unique_y,
+                                   array_ops.zeros_like(unique_y))
       args.extend([unique_y, unique_idx])
 
     @custom_gradient.custom_gradient
```
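This hunk is fix 1: when `blank_index != 0`, `ctc_loss_dense` rotates the blank channel to index 0, which bumps every class id below `blank_index` up by one; the zero-padded tail of `unique_y` gets bumped too and must be re-zeroed. A NumPy walk-through using the example from the `ctc_unique_labels` docstring, where `ctc_unique_labels([[3, 4, 4, 3]])` yields unique labels `[[3, 4, 0, 0]]` and indices `[[0, 1, 1, 0]]` (NumPy stands in for the `array_ops` calls):

```python
import numpy as np

blank_index = 2
unique_y = np.array([[3, 4, 0, 0]])    # unique labels, right-padded with 0
unique_idx = np.array([[0, 1, 1, 0]])  # label position -> slot in unique_y

# Rotating the blank channel to 0 bumps ids below blank_index up by one;
# the pad zeros are bumped to 1 as well.
unique_y = np.where(unique_y < blank_index, unique_y + 1, unique_y)
print(unique_y)  # [[3 4 1 1]] -- the pad slots now wrongly read 1

# Slots never referenced by unique_idx are pure padding: re-zero them.
# (reduce_max(unique_idx, axis=1) + 1 and sequence_mask in the real code.)
label_mask_len = unique_idx.max(axis=1) + 1                    # [2]
label_mask = np.arange(unique_y.shape[1]) < label_mask_len[:, None]
unique_y = np.where(label_mask, unique_y, 0)
print(unique_y)  # [[3 4 0 0]]
```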