diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.cpp b/native_client/ctcdecode/ctc_beam_search_decoder.cpp index 35461c4e..1b7f0e98 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.cpp +++ b/native_client/ctcdecode/ctc_beam_search_decoder.cpp @@ -14,10 +14,11 @@ using FSTMATCH = fst::SortedMatcher; -DecoderState* decoder_init(const Alphabet &alphabet, - int class_dim, - Scorer* ext_scorer) { - +DecoderState* +decoder_init(const Alphabet &alphabet, + int class_dim, + Scorer* ext_scorer) +{ // dimension check VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1, "The shape of probs does not match with " @@ -47,16 +48,17 @@ DecoderState* decoder_init(const Alphabet &alphabet, return state; } -void decoder_next(const double *probs, - const Alphabet &alphabet, - DecoderState *state, - int time_dim, - int class_dim, - double cutoff_prob, - size_t cutoff_top_n, - size_t beam_size, - Scorer *ext_scorer) { - +void +decoder_next(const double *probs, + const Alphabet &alphabet, + DecoderState *state, + int time_dim, + int class_dim, + double cutoff_prob, + size_t cutoff_top_n, + size_t beam_size, + Scorer *ext_scorer) +{ // prefix search over time for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++state->time_step) { auto *prob = &probs[rel_time_step*class_dim]; @@ -155,39 +157,47 @@ void decoder_next(const double *probs, } // end of loop over time } -std::vector decoder_decode(DecoderState *state, - const Alphabet &alphabet, - size_t beam_size, - Scorer* ext_scorer) { +std::vector +decoder_decode(DecoderState *state, + const Alphabet &alphabet, + size_t beam_size, + Scorer* ext_scorer) +{ + std::vector prefixes_copy = state->prefixes; + std::unordered_map scores; + for (PathTrie* prefix : prefixes_copy) { + scores[prefix] = prefix->score; + } // score the last word of each prefix that doesn't end with space if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { - for (size_t i = 0; i < beam_size && i < 
state->prefixes.size(); ++i) { - auto prefix = state->prefixes[i]; + for (size_t i = 0; i < beam_size && i < prefixes_copy.size(); ++i) { + auto prefix = prefixes_copy[i]; if (!prefix->is_empty() && prefix->character != state->space_id) { float score = 0.0; std::vector ngram = ext_scorer->make_ngram(prefix); score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; score += ext_scorer->beta; - prefix->score += score; + scores[prefix] += score; } } } - size_t num_prefixes = std::min(state->prefixes.size(), beam_size); - std::sort(state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare); + using namespace std::placeholders; + size_t num_prefixes = std::min(prefixes_copy.size(), beam_size); + std::sort(prefixes_copy.begin(), prefixes_copy.begin() + num_prefixes, std::bind(prefix_compare_external, _1, _2, std::cref(scores))); /* cref: std::bind would otherwise hold its own copy of the scores map */ //TODO: expose this as an API parameter const int top_paths = 1; // compute aproximate ctc score as the return score, without affecting the // return order of decoding result. To delete when decoder gets stable. 
- for (size_t i = 0; i < top_paths && i < state->prefixes.size(); ++i) { - double approx_ctc = state->prefixes[i]->score; + for (size_t i = 0; i < top_paths && i < prefixes_copy.size(); ++i) { + double approx_ctc = scores[prefixes_copy[i]]; if (ext_scorer != nullptr) { std::vector output; std::vector timesteps; - state->prefixes[i]->get_path_vec(output, timesteps); + prefixes_copy[i]->get_path_vec(output, timesteps); auto prefix_length = output.size(); auto words = ext_scorer->split_labels(output); // remove word insert @@ -195,10 +205,10 @@ std::vector decoder_decode(DecoderState *state, // remove language model weight: approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha; } - state->prefixes[i]->approx_ctc = approx_ctc; + prefixes_copy[i]->approx_ctc = approx_ctc; } - return get_beam_search_result(state->prefixes, top_paths); + return get_beam_search_result(prefixes_copy, top_paths); } std::vector ctc_beam_search_decoder( diff --git a/native_client/ctcdecode/decoder_utils.cpp b/native_client/ctcdecode/decoder_utils.cpp index e8a3da77..b743d92c 100644 --- a/native_client/ctcdecode/decoder_utils.cpp +++ b/native_client/ctcdecode/decoder_utils.cpp @@ -112,6 +112,17 @@ bool prefix_compare(const PathTrie *x, const PathTrie *y) { } } +// Orders prefixes by descending score read from `scores`; equal scores fall back to the smaller `character` value first. +// Each score is looked up once per call; at() throws std::out_of_range when a prefix has no entry in `scores`. +bool prefix_compare_external(const PathTrie *x, const PathTrie *y, const std::unordered_map& scores) { + const auto score_x = scores.at(x); + const auto score_y = scores.at(y); + if (score_x != score_y) { + return score_x > score_y; + } + return x->character < y->character; +} + void add_word_to_fst(const std::vector &word, fst::StdVectorFst *dictionary) { if (dictionary->NumStates() == 0) { diff --git a/native_client/ctcdecode/decoder_utils.h b/native_client/ctcdecode/decoder_utils.h index 54ddf5b7..797810b7 100644 --- a/native_client/ctcdecode/decoder_utils.h +++ b/native_client/ctcdecode/decoder_utils.h @@ -67,6 +67,8 @@ std::vector get_beam_search_result( // Functor for prefix comparsion 
bool prefix_compare(const PathTrie *x, const PathTrie *y); +bool prefix_compare_external(const PathTrie *x, const PathTrie *y, const std::unordered_map& scores); + /* Get length of utf8 encoding string * See: http://stackoverflow.com/a/4063229 */ diff --git a/taskcluster/tc-cpp-ds-tests.sh b/taskcluster/tc-cpp-ds-tests.sh index 7acd27a7..194abc41 100644 --- a/taskcluster/tc-cpp-ds-tests.sh +++ b/taskcluster/tc-cpp-ds-tests.sh @@ -13,3 +13,5 @@ check_tensorflow_version run_all_inference_tests run_multi_inference_tests + +run_cpp_only_inference_tests diff --git a/taskcluster/tc-tests-utils.sh b/taskcluster/tc-tests-utils.sh index dc1e7f3c..11efd016 100755 --- a/taskcluster/tc-tests-utils.sh +++ b/taskcluster/tc-tests-utils.sh @@ -460,6 +460,17 @@ run_multi_inference_tests() assert_correct_multi_ldc93s1 "${multi_phrase_pbmodel_withlm}" "$status" } +# Runs deepspeech with --stream 1280 on LDC93S1.wav (LM and trie enabled) and asserts on the last line of its output. +run_cpp_only_inference_tests() +{ + set +e + # pipefail (bash) inside the substitution so that $status reflects a deepspeech failure, not the exit status of tail + phrase_pbmodel_withlm_intermediate_decode=$(set -o pipefail; deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1.wav --stream 1280 2>${TASKCLUSTER_TMP_DIR}/stderr | tail -n 1) + status=$? + set -e + assert_correct_ldc93s1_lm "${phrase_pbmodel_withlm_intermediate_decode}" "$status" +} + android_run_tests() { cd ${DS_DSDIR}/native_client/java/