Merge pull request #36316 from faustomorales:fix/#35799-ctc-decode
PiperOrigin-RevId: 313619800 Change-Id: I4843b83f42a5425d722d9b509d241941745fb5ff
This commit is contained in:
commit
c197499613
@ -5948,10 +5948,13 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
|
||||
contains the decoded sequence.
|
||||
If `false`, returns the `top_paths` most probable
|
||||
decoded sequences.
|
||||
Each decoded sequence has shape (samples, time_steps).
|
||||
Important: blank labels are returned as `-1`.
|
||||
Tensor `(top_paths, )` that contains
|
||||
the log probability of each decoded sequence.
|
||||
"""
|
||||
input_shape = shape(y_pred)
|
||||
samples, steps = input_shape[0], input_shape[1]
|
||||
y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
|
||||
input_length = math_ops.cast(input_length, dtypes_module.int32)
|
||||
|
||||
@ -5966,7 +5969,7 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
|
||||
top_paths=top_paths)
|
||||
decoded_dense = [
|
||||
sparse_ops.sparse_to_dense(
|
||||
st.indices, st.dense_shape, st.values, default_value=-1)
|
||||
st.indices, (samples, steps), st.values, default_value=-1)
|
||||
for st in decoded
|
||||
]
|
||||
return (decoded_dense, log_prob)
|
||||
|
@ -1762,7 +1762,10 @@ class TestCTC(test.TestCase):
|
||||
-3.777835 # output beam 1
|
||||
], np.float32)[np.newaxis, :]
|
||||
|
||||
decode_truth = [np.array([1, 0]), np.array([0, 1, 0])]
|
||||
decode_truth = [
|
||||
np.array([1, 0, -1, -1, -1, -1, -1]),
|
||||
np.array([0, 1, 0, -1, -1, -1, -1])
|
||||
]
|
||||
beam_width = 2
|
||||
top_paths = 2
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user