Merge pull request #36316 from faustomorales:fix/#35799-ctc-decode

PiperOrigin-RevId: 313619800
Change-Id: I4843b83f42a5425d722d9b509d241941745fb5ff
This commit is contained in:
TensorFlower Gardener 2020-05-28 11:11:44 -07:00
commit c197499613
2 changed files with 8 additions and 2 deletions

View File

@ -5948,10 +5948,13 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
contains the decoded sequence.
If `false`, returns the `top_paths` most probable
decoded sequences.
Each decoded sequence has shape (samples, time_steps).
Important: blank labels are returned as `-1`.
Tensor `(top_paths, )` that contains
the log probability of each decoded sequence.
"""
input_shape = shape(y_pred)
samples, steps = input_shape[0], input_shape[1]
y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
input_length = math_ops.cast(input_length, dtypes_module.int32)
@ -5966,7 +5969,7 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
top_paths=top_paths)
decoded_dense = [
sparse_ops.sparse_to_dense(
-st.indices, st.dense_shape, st.values, default_value=-1)
+st.indices, (samples, steps), st.values, default_value=-1)
for st in decoded
]
return (decoded_dense, log_prob)

View File

@ -1762,7 +1762,10 @@ class TestCTC(test.TestCase):
-3.777835 # output beam 1
], np.float32)[np.newaxis, :]
-decode_truth = [np.array([1, 0]), np.array([0, 1, 0])]
+decode_truth = [
+np.array([1, 0, -1, -1, -1, -1, -1]),
+np.array([0, 1, 0, -1, -1, -1, -1])
+]
beam_width = 2
top_paths = 2