Merge pull request #2145 from mozilla/decoder-optimizations

Decoder optimizations
This commit is contained in:
Reuben Morais 2019-06-04 13:37:59 -03:00 committed by GitHub
commit 10d98e1df9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 10 additions and 16 deletions

View File

@@ -176,9 +176,12 @@ std::vector<Output> decoder_decode(DecoderState *state,
size_t num_prefixes = std::min(state->prefixes.size(), beam_size); size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
std::sort(state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare); std::sort(state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);
//TODO: expose this as an API parameter
const int top_paths = 1;
// compute approximate ctc score as the return score, without affecting the // compute approximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable. // return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) { for (size_t i = 0; i < top_paths && i < state->prefixes.size(); ++i) {
double approx_ctc = state->prefixes[i]->score; double approx_ctc = state->prefixes[i]->score;
if (ext_scorer != nullptr) { if (ext_scorer != nullptr) {
std::vector<int> output; std::vector<int> output;
@@ -194,7 +197,7 @@ std::vector<Output> decoder_decode(DecoderState *state,
state->prefixes[i]->approx_ctc = approx_ctc; state->prefixes[i]->approx_ctc = approx_ctc;
} }
return get_beam_search_result(state->prefixes, beam_size); return get_beam_search_result(state->prefixes, top_paths);
} }
std::vector<Output> ctc_beam_search_decoder( std::vector<Output> ctc_beam_search_decoder(

View File

@@ -41,21 +41,12 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs(
std::vector<Output> get_beam_search_result( std::vector<Output> get_beam_search_result(
const std::vector<PathTrie *> &prefixes, const std::vector<PathTrie *> &prefixes,
size_t beam_size) { size_t top_paths) {
// allow for the post processing
std::vector<PathTrie *> space_prefixes;
if (space_prefixes.empty()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
space_prefixes.push_back(prefixes[i]);
}
}
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
std::vector<Output> output_vecs; std::vector<Output> output_vecs;
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) { for (size_t i = 0; i < top_paths && i < prefixes.size(); ++i) {
Output output; Output output;
space_prefixes[i]->get_path_vec(output.tokens, output.timesteps); prefixes[i]->get_path_vec(output.tokens, output.timesteps);
output.probability = -space_prefixes[i]->approx_ctc; output.probability = -prefixes[i]->approx_ctc;
output_vecs.push_back(output); output_vecs.push_back(output);
} }

View File

@@ -62,7 +62,7 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs(
// Get beam search result from prefixes in trie tree // Get beam search result from prefixes in trie tree
std::vector<Output> get_beam_search_result( std::vector<Output> get_beam_search_result(
const std::vector<PathTrie *> &prefixes, const std::vector<PathTrie *> &prefixes,
size_t beam_size); size_t top_paths);
// Functor for prefix comparison // Functor for prefix comparison
bool prefix_compare(const PathTrie *x, const PathTrie *y); bool prefix_compare(const PathTrie *x, const PathTrie *y);