Add num_results param to ctc_beam_search_decoder
This commit is contained in:
parent
ce71910ab4
commit
c20af74d51
@ -95,7 +95,8 @@ def ctc_beam_search_decoder(probs_seq,
|
||||
beam_size,
|
||||
cutoff_prob=1.0,
|
||||
cutoff_top_n=40,
|
||||
scorer=None):
|
||||
scorer=None,
|
||||
num_results=1):
|
||||
"""Wrapper for the CTC Beam Search Decoder.
|
||||
|
||||
:param probs_seq: 2-D list of probability distributions over each time
|
||||
@ -115,13 +116,15 @@ def ctc_beam_search_decoder(probs_seq,
|
||||
:param scorer: External scorer for partially decoded sentence, e.g. word
|
||||
count or language model.
|
||||
:type scorer: Scorer
|
||||
:param num_results: Number of beams to return.
|
||||
:type num_results: int
|
||||
:return: List of tuples of confidence and sentence as decoding
|
||||
results, in descending order of the confidence.
|
||||
:rtype: list
|
||||
"""
|
||||
beam_results = swigwrapper.ctc_beam_search_decoder(
|
||||
probs_seq, alphabet, beam_size, cutoff_prob, cutoff_top_n,
|
||||
scorer)
|
||||
scorer, num_results)
|
||||
beam_results = [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
|
||||
return beam_results
|
||||
|
||||
@ -133,7 +136,8 @@ def ctc_beam_search_decoder_batch(probs_seq,
|
||||
num_processes,
|
||||
cutoff_prob=1.0,
|
||||
cutoff_top_n=40,
|
||||
scorer=None):
|
||||
scorer=None,
|
||||
num_results=1):
|
||||
"""Wrapper for the batched CTC beam search decoder.
|
||||
|
||||
:param probs_seq: 3-D list with each element as an instance of 2-D list
|
||||
@ -157,11 +161,13 @@ def ctc_beam_search_decoder_batch(probs_seq,
|
||||
:param scorer: External scorer for partially decoded sentence, e.g. word
|
||||
count or language model.
|
||||
:type scorer: Scorer
|
||||
:param num_results: Number of beams to return.
|
||||
:type num_results: int
|
||||
:return: List of tuples of confidence and sentence as decoding
|
||||
results, in descending order of the confidence.
|
||||
:rtype: list
|
||||
"""
|
||||
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer)
|
||||
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer, num_results)
|
||||
batch_beam_results = [
|
||||
[(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
|
||||
for beam_results in batch_beam_results
|
||||
|
@ -223,12 +223,13 @@ std::vector<Output> ctc_beam_search_decoder(
|
||||
size_t beam_size,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
std::shared_ptr<Scorer> ext_scorer)
|
||||
std::shared_ptr<Scorer> ext_scorer,
|
||||
size_t num_results)
|
||||
{
|
||||
DecoderState state;
|
||||
state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer);
|
||||
state.next(probs, time_dim, class_dim);
|
||||
return state.decode();
|
||||
return state.decode(num_results);
|
||||
}
|
||||
|
||||
std::vector<std::vector<Output>>
|
||||
@ -244,7 +245,8 @@ ctc_beam_search_decoder_batch(
|
||||
size_t num_processes,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
std::shared_ptr<Scorer> ext_scorer)
|
||||
std::shared_ptr<Scorer> ext_scorer,
|
||||
size_t num_results)
|
||||
{
|
||||
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
|
||||
VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element");
|
||||
@ -262,7 +264,8 @@ ctc_beam_search_decoder_batch(
|
||||
beam_size,
|
||||
cutoff_prob,
|
||||
cutoff_top_n,
|
||||
ext_scorer));
|
||||
ext_scorer,
|
||||
num_results));
|
||||
}
|
||||
|
||||
// get decoding results
|
||||
|
@ -87,6 +87,7 @@ public:
|
||||
* ext_scorer: External scorer to evaluate a prefix, which consists of
|
||||
* n-gram language model scoring and word insertion term.
|
||||
* Default null, decoding the input sample without scorer.
|
||||
* num_results: Number of beams to return. Default 1.
|
||||
* Return:
|
||||
* A vector where each element is a pair of score and decoding result,
|
||||
* in descending order.
|
||||
@ -100,7 +101,8 @@ std::vector<Output> ctc_beam_search_decoder(
|
||||
size_t beam_size,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
std::shared_ptr<Scorer> ext_scorer);
|
||||
std::shared_ptr<Scorer> ext_scorer,
|
||||
size_t num_results=1);
|
||||
|
||||
/* CTC Beam Search Decoder for batch data
|
||||
* Parameters:
|
||||
@ -114,6 +116,7 @@ std::vector<Output> ctc_beam_search_decoder(
|
||||
* ext_scorer: External scorer to evaluate a prefix, which consists of
|
||||
* n-gram language model scoring and word insertion term.
|
||||
* Default null, decoding the input sample without scorer.
|
||||
* num_results: Number of beams to return for each sample. Default 1.
|
||||
* Return:
|
||||
* A 2-D vector where each element is a vector of beam search decoding
|
||||
* result for one audio sample.
|
||||
@ -131,6 +134,7 @@ ctc_beam_search_decoder_batch(
|
||||
size_t num_processes,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
std::shared_ptr<Scorer> ext_scorer);
|
||||
std::shared_ptr<Scorer> ext_scorer,
|
||||
size_t num_results=1);
|
||||
|
||||
#endif // CTC_BEAM_SEARCH_DECODER_H_
|
||||
|
Loading…
x
Reference in New Issue
Block a user