Pad keras.backend.ctc_decode output to consistent shape.

This commit is contained in:
Fausto Morales 2020-01-28 23:41:39 -06:00
parent 0a3c298880
commit a24a75aee1
2 changed files with 5 additions and 2 deletions

View File

@@ -5824,10 +5824,13 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
contains the decoded sequence.
If `false`, returns the `top_paths` most probable
decoded sequences.
Each decoded sequence has shape (samples, time_steps).
Important: blank labels are returned as `-1`.
Tensor `(top_paths, )` that contains
the log probability of each decoded sequence.
"""
input_shape = shape(y_pred)
samples, steps = input_shape[0], input_shape[1]
y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
input_length = math_ops.cast(input_length, dtypes_module.int32)
@@ -5842,7 +5845,7 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
top_paths=top_paths)
decoded_dense = [
sparse_ops.sparse_to_dense(
st.indices, st.dense_shape, st.values, default_value=-1)
st.indices, (samples, steps), st.values, default_value=-1)
for st in decoded
]
return (decoded_dense, log_prob)

View File

@@ -1771,7 +1771,7 @@ class TestCTC(test.TestCase):
-3.777835 # output beam 1
], np.float32)[np.newaxis, :]
decode_truth = [np.array([1, 0]), np.array([0, 1, 0])]
decode_truth = [np.array([1, 0, -1, -1, -1, -1, -1]), np.array([0, 1, 0, -1, -1, -1 ,-1])]
beam_width = 2
top_paths = 2