Pad keras.backend.ctc_decode output to consistent shape.
This commit is contained in:
parent
0a3c298880
commit
a24a75aee1
@ -5824,10 +5824,13 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
|
||||
contains the decoded sequence.
|
||||
If `false`, returns the `top_paths` most probable
|
||||
decoded sequences.
|
||||
Each decoded sequence has shape (samples, time_steps).
|
||||
Important: blank labels are returned as `-1`.
|
||||
Tensor `(top_paths, )` that contains
|
||||
the log probability of each decoded sequence.
|
||||
"""
|
||||
input_shape = shape(y_pred)
|
||||
samples, steps = input_shape[0], input_shape[1]
|
||||
y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
|
||||
input_length = math_ops.cast(input_length, dtypes_module.int32)
|
||||
|
||||
@ -5842,7 +5845,7 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
|
||||
top_paths=top_paths)
|
||||
decoded_dense = [
|
||||
sparse_ops.sparse_to_dense(
|
||||
st.indices, st.dense_shape, st.values, default_value=-1)
|
||||
st.indices, (samples, steps), st.values, default_value=-1)
|
||||
for st in decoded
|
||||
]
|
||||
return (decoded_dense, log_prob)
|
||||
|
@ -1771,7 +1771,7 @@ class TestCTC(test.TestCase):
|
||||
-3.777835 # output beam 1
|
||||
], np.float32)[np.newaxis, :]
|
||||
|
||||
decode_truth = [np.array([1, 0]), np.array([0, 1, 0])]
|
||||
decode_truth = [np.array([1, 0, -1, -1, -1, -1, -1]), np.array([0, 1, 0, -1, -1, -1 ,-1])]
|
||||
beam_width = 2
|
||||
top_paths = 2
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user