Merge pull request #3251 from Jendker/ctc_multiple_transc
Add num_results param to ctc_beam_search_decoder
This commit is contained in:
commit
90e04fb365
@ -95,7 +95,8 @@ def ctc_beam_search_decoder(probs_seq,
|
|||||||
beam_size,
|
beam_size,
|
||||||
cutoff_prob=1.0,
|
cutoff_prob=1.0,
|
||||||
cutoff_top_n=40,
|
cutoff_top_n=40,
|
||||||
scorer=None):
|
scorer=None,
|
||||||
|
num_results=1):
|
||||||
"""Wrapper for the CTC Beam Search Decoder.
|
"""Wrapper for the CTC Beam Search Decoder.
|
||||||
|
|
||||||
:param probs_seq: 2-D list of probability distributions over each time
|
:param probs_seq: 2-D list of probability distributions over each time
|
||||||
@ -115,13 +116,15 @@ def ctc_beam_search_decoder(probs_seq,
|
|||||||
:param scorer: External scorer for partially decoded sentence, e.g. word
|
:param scorer: External scorer for partially decoded sentence, e.g. word
|
||||||
count or language model.
|
count or language model.
|
||||||
:type scorer: Scorer
|
:type scorer: Scorer
|
||||||
|
:param num_results: Number of beams to return.
|
||||||
|
:type num_results: int
|
||||||
:return: List of tuples of confidence and sentence as decoding
|
:return: List of tuples of confidence and sentence as decoding
|
||||||
results, in descending order of the confidence.
|
results, in descending order of the confidence.
|
||||||
:rtype: list
|
:rtype: list
|
||||||
"""
|
"""
|
||||||
beam_results = swigwrapper.ctc_beam_search_decoder(
|
beam_results = swigwrapper.ctc_beam_search_decoder(
|
||||||
probs_seq, alphabet, beam_size, cutoff_prob, cutoff_top_n,
|
probs_seq, alphabet, beam_size, cutoff_prob, cutoff_top_n,
|
||||||
scorer)
|
scorer, num_results)
|
||||||
beam_results = [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
|
beam_results = [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
|
||||||
return beam_results
|
return beam_results
|
||||||
|
|
||||||
@ -133,7 +136,8 @@ def ctc_beam_search_decoder_batch(probs_seq,
|
|||||||
num_processes,
|
num_processes,
|
||||||
cutoff_prob=1.0,
|
cutoff_prob=1.0,
|
||||||
cutoff_top_n=40,
|
cutoff_top_n=40,
|
||||||
scorer=None):
|
scorer=None,
|
||||||
|
num_results=1):
|
||||||
"""Wrapper for the batched CTC beam search decoder.
|
"""Wrapper for the batched CTC beam search decoder.
|
||||||
|
|
||||||
:param probs_seq: 3-D list with each element as an instance of 2-D list
|
:param probs_seq: 3-D list with each element as an instance of 2-D list
|
||||||
@ -157,11 +161,13 @@ def ctc_beam_search_decoder_batch(probs_seq,
|
|||||||
:param scorer: External scorer for partially decoded sentence, e.g. word
|
:param scorer: External scorer for partially decoded sentence, e.g. word
|
||||||
count or language model.
|
count or language model.
|
||||||
:type scorer: Scorer
|
:type scorer: Scorer
|
||||||
|
:param num_results: Number of beams to return.
|
||||||
|
:type num_results: int
|
||||||
:return: List of tuples of confidence and sentence as decoding
|
:return: List of tuples of confidence and sentence as decoding
|
||||||
results, in descending order of the confidence.
|
results, in descending order of the confidence.
|
||||||
:rtype: list
|
:rtype: list
|
||||||
"""
|
"""
|
||||||
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer)
|
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer, num_results)
|
||||||
batch_beam_results = [
|
batch_beam_results = [
|
||||||
[(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
|
[(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
|
||||||
for beam_results in batch_beam_results
|
for beam_results in batch_beam_results
|
||||||
|
@ -223,12 +223,13 @@ std::vector<Output> ctc_beam_search_decoder(
|
|||||||
size_t beam_size,
|
size_t beam_size,
|
||||||
double cutoff_prob,
|
double cutoff_prob,
|
||||||
size_t cutoff_top_n,
|
size_t cutoff_top_n,
|
||||||
std::shared_ptr<Scorer> ext_scorer)
|
std::shared_ptr<Scorer> ext_scorer,
|
||||||
|
size_t num_results)
|
||||||
{
|
{
|
||||||
DecoderState state;
|
DecoderState state;
|
||||||
state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer);
|
state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer);
|
||||||
state.next(probs, time_dim, class_dim);
|
state.next(probs, time_dim, class_dim);
|
||||||
return state.decode();
|
return state.decode(num_results);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<Output>>
|
std::vector<std::vector<Output>>
|
||||||
@ -244,7 +245,8 @@ ctc_beam_search_decoder_batch(
|
|||||||
size_t num_processes,
|
size_t num_processes,
|
||||||
double cutoff_prob,
|
double cutoff_prob,
|
||||||
size_t cutoff_top_n,
|
size_t cutoff_top_n,
|
||||||
std::shared_ptr<Scorer> ext_scorer)
|
std::shared_ptr<Scorer> ext_scorer,
|
||||||
|
size_t num_results)
|
||||||
{
|
{
|
||||||
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
|
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
|
||||||
VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element");
|
VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element");
|
||||||
@ -262,7 +264,8 @@ ctc_beam_search_decoder_batch(
|
|||||||
beam_size,
|
beam_size,
|
||||||
cutoff_prob,
|
cutoff_prob,
|
||||||
cutoff_top_n,
|
cutoff_top_n,
|
||||||
ext_scorer));
|
ext_scorer,
|
||||||
|
num_results));
|
||||||
}
|
}
|
||||||
|
|
||||||
// get decoding results
|
// get decoding results
|
||||||
|
@ -87,6 +87,7 @@ public:
|
|||||||
* ext_scorer: External scorer to evaluate a prefix, which consists of
|
* ext_scorer: External scorer to evaluate a prefix, which consists of
|
||||||
* n-gram language model scoring and word insertion term.
|
* n-gram language model scoring and word insertion term.
|
||||||
* Default null, decoding the input sample without scorer.
|
* Default null, decoding the input sample without scorer.
|
||||||
|
* num_results: Number of beams to return.
|
||||||
* Return:
|
* Return:
|
||||||
* A vector where each element is a pair of score and decoding result,
|
* A vector where each element is a pair of score and decoding result,
|
||||||
* in descending order.
|
* in descending order.
|
||||||
@ -100,7 +101,8 @@ std::vector<Output> ctc_beam_search_decoder(
|
|||||||
size_t beam_size,
|
size_t beam_size,
|
||||||
double cutoff_prob,
|
double cutoff_prob,
|
||||||
size_t cutoff_top_n,
|
size_t cutoff_top_n,
|
||||||
std::shared_ptr<Scorer> ext_scorer);
|
std::shared_ptr<Scorer> ext_scorer,
|
||||||
|
size_t num_results=1);
|
||||||
|
|
||||||
/* CTC Beam Search Decoder for batch data
|
/* CTC Beam Search Decoder for batch data
|
||||||
* Parameters:
|
* Parameters:
|
||||||
@ -114,6 +116,7 @@ std::vector<Output> ctc_beam_search_decoder(
|
|||||||
* ext_scorer: External scorer to evaluate a prefix, which consists of
|
* ext_scorer: External scorer to evaluate a prefix, which consists of
|
||||||
* n-gram language model scoring and word insertion term.
|
* n-gram language model scoring and word insertion term.
|
||||||
* Default null, decoding the input sample without scorer.
|
* Default null, decoding the input sample without scorer.
|
||||||
|
* num_results: Number of beams to return.
|
||||||
* Return:
|
* Return:
|
||||||
* A 2-D vector where each element is a vector of beam search decoding
|
* A 2-D vector where each element is a vector of beam search decoding
|
||||||
* result for one audio sample.
|
* result for one audio sample.
|
||||||
@ -131,6 +134,7 @@ ctc_beam_search_decoder_batch(
|
|||||||
size_t num_processes,
|
size_t num_processes,
|
||||||
double cutoff_prob,
|
double cutoff_prob,
|
||||||
size_t cutoff_top_n,
|
size_t cutoff_top_n,
|
||||||
std::shared_ptr<Scorer> ext_scorer);
|
std::shared_ptr<Scorer> ext_scorer,
|
||||||
|
size_t num_results=1);
|
||||||
|
|
||||||
#endif // CTC_BEAM_SEARCH_DECODER_H_
|
#endif // CTC_BEAM_SEARCH_DECODER_H_
|
||||||
|
Loading…
x
Reference in New Issue
Block a user