Avoid sorting prefix array twice

This commit is contained in:
Reuben Morais 2019-06-01 15:06:56 -03:00
parent a46288e1c8
commit 1c87bf781a

View File

@ -41,21 +41,12 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs(
std::vector<Output> get_beam_search_result(
const std::vector<PathTrie *> &prefixes,
size_t beam_size) {
// allow for the post processing
std::vector<PathTrie *> space_prefixes;
if (space_prefixes.empty()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
space_prefixes.push_back(prefixes[i]);
}
}
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
size_t top_paths) {
std::vector<Output> output_vecs;
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
for (size_t i = 0; i < top_paths && i < prefixes.size(); ++i) {
Output output;
space_prefixes[i]->get_path_vec(output.tokens, output.timesteps);
output.probability = -space_prefixes[i]->approx_ctc;
prefixes[i]->get_path_vec(output.tokens, output.timesteps);
output.probability = -prefixes[i]->approx_ctc;
output_vecs.push_back(output);
}