Avoid sorting prefix array twice
This commit is contained in:
parent
a46288e1c8
commit
1c87bf781a
@ -41,21 +41,12 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs(
|
|||||||
|
|
||||||
std::vector<Output> get_beam_search_result(
|
std::vector<Output> get_beam_search_result(
|
||||||
const std::vector<PathTrie *> &prefixes,
|
const std::vector<PathTrie *> &prefixes,
|
||||||
size_t beam_size) {
|
size_t top_paths) {
|
||||||
// allow for the post processing
|
|
||||||
std::vector<PathTrie *> space_prefixes;
|
|
||||||
if (space_prefixes.empty()) {
|
|
||||||
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
|
|
||||||
space_prefixes.push_back(prefixes[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
|
|
||||||
std::vector<Output> output_vecs;
|
std::vector<Output> output_vecs;
|
||||||
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
|
for (size_t i = 0; i < top_paths && i < prefixes.size(); ++i) {
|
||||||
Output output;
|
Output output;
|
||||||
space_prefixes[i]->get_path_vec(output.tokens, output.timesteps);
|
prefixes[i]->get_path_vec(output.tokens, output.timesteps);
|
||||||
output.probability = -space_prefixes[i]->approx_ctc;
|
output.probability = -prefixes[i]->approx_ctc;
|
||||||
output_vecs.push_back(output);
|
output_vecs.push_back(output);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user