Only create Output structure for beams that are returned in the API
This commit is contained in:
parent
1201739af2
commit
a46288e1c8
@ -176,9 +176,12 @@ std::vector<Output> decoder_decode(DecoderState *state,
|
|||||||
size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
|
size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
|
||||||
std::sort(state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);
|
std::sort(state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);
|
||||||
|
|
||||||
|
//TODO: expose this as an API parameter
|
||||||
|
const int top_paths = 1;
|
||||||
|
|
||||||
// compute aproximate ctc score as the return score, without affecting the
|
// compute aproximate ctc score as the return score, without affecting the
|
||||||
// return order of decoding result. To delete when decoder gets stable.
|
// return order of decoding result. To delete when decoder gets stable.
|
||||||
for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) {
|
for (size_t i = 0; i < top_paths && i < state->prefixes.size(); ++i) {
|
||||||
double approx_ctc = state->prefixes[i]->score;
|
double approx_ctc = state->prefixes[i]->score;
|
||||||
if (ext_scorer != nullptr) {
|
if (ext_scorer != nullptr) {
|
||||||
std::vector<int> output;
|
std::vector<int> output;
|
||||||
@ -194,7 +197,7 @@ std::vector<Output> decoder_decode(DecoderState *state,
|
|||||||
state->prefixes[i]->approx_ctc = approx_ctc;
|
state->prefixes[i]->approx_ctc = approx_ctc;
|
||||||
}
|
}
|
||||||
|
|
||||||
return get_beam_search_result(state->prefixes, beam_size);
|
return get_beam_search_result(state->prefixes, top_paths);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Output> ctc_beam_search_decoder(
|
std::vector<Output> ctc_beam_search_decoder(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user