Merge pull request #3251 from Jendker/ctc_multiple_transc

Add num_results param to ctc_beam_search_decoder
This commit is contained in:
lissyx 2020-08-17 20:07:59 +02:00 committed by GitHub
commit 90e04fb365
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 10 deletions

View File

@ -95,7 +95,8 @@ def ctc_beam_search_decoder(probs_seq,
beam_size, beam_size,
cutoff_prob=1.0, cutoff_prob=1.0,
cutoff_top_n=40, cutoff_top_n=40,
scorer=None): scorer=None,
num_results=1):
"""Wrapper for the CTC Beam Search Decoder. """Wrapper for the CTC Beam Search Decoder.
:param probs_seq: 2-D list of probability distributions over each time :param probs_seq: 2-D list of probability distributions over each time
@ -115,13 +116,15 @@ def ctc_beam_search_decoder(probs_seq,
:param scorer: External scorer for partially decoded sentence, e.g. word :param scorer: External scorer for partially decoded sentence, e.g. word
count or language model. count or language model.
:type scorer: Scorer :type scorer: Scorer
:param num_results: Number of beams (candidate transcriptions) to return. Defaults to 1.
:type num_results: int
:return: List of tuples of confidence and sentence as decoding :return: List of tuples of confidence and sentence as decoding
results, in descending order of the confidence. results, in descending order of the confidence.
:rtype: list :rtype: list
""" """
beam_results = swigwrapper.ctc_beam_search_decoder( beam_results = swigwrapper.ctc_beam_search_decoder(
probs_seq, alphabet, beam_size, cutoff_prob, cutoff_top_n, probs_seq, alphabet, beam_size, cutoff_prob, cutoff_top_n,
scorer) scorer, num_results)
beam_results = [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results] beam_results = [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
return beam_results return beam_results
@ -133,7 +136,8 @@ def ctc_beam_search_decoder_batch(probs_seq,
num_processes, num_processes,
cutoff_prob=1.0, cutoff_prob=1.0,
cutoff_top_n=40, cutoff_top_n=40,
scorer=None): scorer=None,
num_results=1):
"""Wrapper for the batched CTC beam search decoder. """Wrapper for the batched CTC beam search decoder.
:param probs_seq: 3-D list with each element as an instance of 2-D list :param probs_seq: 3-D list with each element as an instance of 2-D list
@ -157,11 +161,13 @@ def ctc_beam_search_decoder_batch(probs_seq,
:param scorer: External scorer for partially decoded sentence, e.g. word :param scorer: External scorer for partially decoded sentence, e.g. word
count or language model. count or language model.
:type scorer: Scorer :type scorer: Scorer
:param num_results: Number of beams (candidate transcriptions) to return for each batch element. Defaults to 1.
:type num_results: int
:return: List of tuples of confidence and sentence as decoding :return: List of tuples of confidence and sentence as decoding
results, in descending order of the confidence. results, in descending order of the confidence.
:rtype: list :rtype: list
""" """
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer) batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer, num_results)
batch_beam_results = [ batch_beam_results = [
[(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results] [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
for beam_results in batch_beam_results for beam_results in batch_beam_results

View File

@ -223,12 +223,13 @@ std::vector<Output> ctc_beam_search_decoder(
size_t beam_size, size_t beam_size,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer) std::shared_ptr<Scorer> ext_scorer,
size_t num_results)
{ {
DecoderState state; DecoderState state;
state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer); state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer);
state.next(probs, time_dim, class_dim); state.next(probs, time_dim, class_dim);
return state.decode(); return state.decode(num_results);
} }
std::vector<std::vector<Output>> std::vector<std::vector<Output>>
@ -244,7 +245,8 @@ ctc_beam_search_decoder_batch(
size_t num_processes, size_t num_processes,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer) std::shared_ptr<Scorer> ext_scorer,
size_t num_results)
{ {
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element"); VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element");
@ -262,7 +264,8 @@ ctc_beam_search_decoder_batch(
beam_size, beam_size,
cutoff_prob, cutoff_prob,
cutoff_top_n, cutoff_top_n,
ext_scorer)); ext_scorer,
num_results));
} }
// get decoding results // get decoding results

View File

@ -87,6 +87,7 @@ public:
* ext_scorer: External scorer to evaluate a prefix, which consists of * ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term. * n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer. * Default null, decoding the input sample without scorer.
* num_results: Number of beams to return. Default 1.
* Return: * Return:
* A vector where each element is a pair of score and decoding result, * A vector where each element is a pair of score and decoding result,
* in descending order. * in descending order.
@ -100,7 +101,8 @@ std::vector<Output> ctc_beam_search_decoder(
size_t beam_size, size_t beam_size,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer); std::shared_ptr<Scorer> ext_scorer,
size_t num_results=1);
/* CTC Beam Search Decoder for batch data /* CTC Beam Search Decoder for batch data
* Parameters: * Parameters:
@ -114,6 +116,7 @@ std::vector<Output> ctc_beam_search_decoder(
* ext_scorer: External scorer to evaluate a prefix, which consists of * ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term. * n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer. * Default null, decoding the input sample without scorer.
* num_results: Number of beams to return for each audio sample. Default 1.
* Return: * Return:
* A 2-D vector where each element is a vector of beam search decoding * A 2-D vector where each element is a vector of beam search decoding
* result for one audio sample. * result for one audio sample.
@ -131,6 +134,7 @@ ctc_beam_search_decoder_batch(
size_t num_processes, size_t num_processes,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer); std::shared_ptr<Scorer> ext_scorer,
size_t num_results=1);
#endif // CTC_BEAM_SEARCH_DECODER_H_ #endif // CTC_BEAM_SEARCH_DECODER_H_