diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.cpp b/native_client/ctcdecode/ctc_beam_search_decoder.cpp index d12e81fc..a50f731f 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.cpp +++ b/native_client/ctcdecode/ctc_beam_search_decoder.cpp @@ -176,9 +176,12 @@ std::vector decoder_decode(DecoderState *state, size_t num_prefixes = std::min(state->prefixes.size(), beam_size); std::sort(state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare); + //TODO: expose this as an API parameter + const int top_paths = 1; + // compute aproximate ctc score as the return score, without affecting the // return order of decoding result. To delete when decoder gets stable. - for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) { + for (size_t i = 0; i < top_paths && i < state->prefixes.size(); ++i) { double approx_ctc = state->prefixes[i]->score; if (ext_scorer != nullptr) { std::vector output; @@ -194,7 +197,7 @@ std::vector decoder_decode(DecoderState *state, state->prefixes[i]->approx_ctc = approx_ctc; } - return get_beam_search_result(state->prefixes, beam_size); + return get_beam_search_result(state->prefixes, top_paths); } std::vector ctc_beam_search_decoder( diff --git a/native_client/ctcdecode/decoder_utils.cpp b/native_client/ctcdecode/decoder_utils.cpp index c0505fb3..e8a3da77 100644 --- a/native_client/ctcdecode/decoder_utils.cpp +++ b/native_client/ctcdecode/decoder_utils.cpp @@ -41,21 +41,12 @@ std::vector> get_pruned_log_probs( std::vector get_beam_search_result( const std::vector &prefixes, - size_t beam_size) { - // allow for the post processing - std::vector space_prefixes; - if (space_prefixes.empty()) { - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - space_prefixes.push_back(prefixes[i]); - } - } - - std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare); + size_t top_paths) { std::vector output_vecs; - for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) { + for (size_t i = 0; i < top_paths && i < prefixes.size(); ++i) { Output output; - space_prefixes[i]->get_path_vec(output.tokens, output.timesteps); - output.probability = -space_prefixes[i]->approx_ctc; + prefixes[i]->get_path_vec(output.tokens, output.timesteps); + output.probability = -prefixes[i]->approx_ctc; output_vecs.push_back(output); } diff --git a/native_client/ctcdecode/decoder_utils.h b/native_client/ctcdecode/decoder_utils.h index f3c1977d..54ddf5b7 100644 --- a/native_client/ctcdecode/decoder_utils.h +++ b/native_client/ctcdecode/decoder_utils.h @@ -62,7 +62,7 @@ std::vector> get_pruned_log_probs( // Get beam search result from prefixes in trie tree std::vector get_beam_search_result( const std::vector &prefixes, - size_t beam_size); + size_t top_paths); // Functor for prefix comparsion bool prefix_compare(const PathTrie *x, const PathTrie *y);