From 04a36fbf68b218199667b1b915f12cbcf884d925 Mon Sep 17 00:00:00 2001 From: godeffroy Date: Tue, 25 Aug 2020 12:08:14 +0200 Subject: [PATCH] The CTC decoder timesteps now correspond to the timesteps of the most probable CTC path, instead of the earliest timesteps of all possible paths. --- .../ctcdecode/ctc_beam_search_decoder.cpp | 46 ++++++++++++++++--- native_client/ctcdecode/path_trie.cpp | 34 ++++++-------- native_client/ctcdecode/path_trie.h | 11 +++-- native_client/ctcdecode/scorer.cpp | 6 +-- 4 files changed, 64 insertions(+), 33 deletions(-) diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.cpp b/native_client/ctcdecode/ctc_beam_search_decoder.cpp index d46b6893..823313c8 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.cpp +++ b/native_client/ctcdecode/ctc_beam_search_decoder.cpp @@ -96,24 +96,52 @@ DecoderState::next(const double *probs, if (full_beam && log_prob_c + prefix->score < min_cutoff) { break; } + if (prefix->score == -NUM_FLT_INF) { + continue; + } + if (!prefix->is_empty() && prefix->timesteps.empty()) { + // This should never happen. But we report it if it does.
+ std::cerr<<"error: non-empty prefix has empty timestep sequence"<<std::endl; + } + + // blank if (c == blank_id) { + // compute probability of current path + float log_p = log_prob_c + prefix->score; + + // combine current path with previous ones with the same prefix + // the blank label comes last, so we can compare log_prob_nb_cur with log_p + if (prefix->log_prob_nb_cur < log_p) { + prefix->timesteps_cur = prefix->timesteps; + } prefix->log_prob_b_cur = - log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); + log_sum_exp(prefix->log_prob_b_cur, log_p); continue; } // repeated character if (c == prefix->character) { + // compute probability of current path + float log_p = log_prob_c + prefix->log_prob_nb_prev; + + // combine current path with previous ones with the same prefix + if (prefix->log_prob_nb_cur < log_p) { + prefix->timesteps_cur = prefix->timesteps; + } prefix->log_prob_nb_cur = log_sum_exp( - prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev); + prefix->log_prob_nb_cur, log_p); } // get new prefix - auto prefix_new = prefix->get_path_trie(c, abs_time_step_, log_prob_c); + auto prefix_new = prefix->get_path_trie(c, log_prob_c); if (prefix_new != nullptr) { + // compute timesteps of current path + std::vector<unsigned int> timesteps_new=prefix->timesteps; + timesteps_new.push_back(abs_time_step_); + + // compute probability of current path float log_p = -NUM_FLT_INF; if (c == prefix->character && @@ -144,6 +172,10 @@ DecoderState::next(const double *probs, } } + // combine current path with previous ones with the same prefix + if (prefix_new->log_prob_nb_cur < log_p) { + prefix_new->timesteps_cur = timesteps_new; + } prefix_new->log_prob_nb_cur = log_sum_exp(prefix_new->log_prob_nb_cur, log_p); } @@ -205,11 +237,13 @@ DecoderState::decode(size_t num_results) const std::vector<Output> outputs; outputs.reserve(num_returned); - for (size_t i = 0; i < num_returned; ++i) { + for (PathTrie* prefix : prefixes_copy) { Output output; - prefixes_copy[i]->get_path_vec(output.tokens, output.timesteps); - output.confidence = scores[prefixes_copy[i]]; + output.tokens =
prefix->get_path_vec(); + output.timesteps = prefix->timesteps; + output.confidence = scores[prefix]; outputs.push_back(output); + if(outputs.size()>=num_returned) break; } return outputs; diff --git a/native_client/ctcdecode/path_trie.cpp b/native_client/ctcdecode/path_trie.cpp index 7a04f693..55a81437 100644 --- a/native_client/ctcdecode/path_trie.cpp +++ b/native_client/ctcdecode/path_trie.cpp @@ -18,7 +18,6 @@ PathTrie::PathTrie() { ROOT_ = -1; character = ROOT_; - timestep = 0; exists_ = true; parent = nullptr; @@ -35,7 +34,7 @@ PathTrie::~PathTrie() { } } -PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timestep, float cur_log_prob_c, bool reset) { +PathTrie* PathTrie::get_path_trie(unsigned int new_char, float cur_log_prob_c, bool reset) { auto child = children_.begin(); for (; child != children_.end(); ++child) { if (child->first == new_char) { @@ -67,7 +66,6 @@ PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timest } else { PathTrie* new_path = new PathTrie; new_path->character = new_char; - new_path->timestep = new_timestep; new_path->parent = this; new_path->dictionary_ = dictionary_; new_path->has_dictionary_ = true; @@ -93,7 +91,6 @@ PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timest } else { PathTrie* new_path = new PathTrie; new_path->character = new_char; - new_path->timestep = new_timestep; new_path->parent = this; new_path->log_prob_c = cur_log_prob_c; children_.push_back(std::make_pair(new_char, new_path)); @@ -102,20 +99,18 @@ PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timest } } -void PathTrie::get_path_vec(std::vector<unsigned int>& output, std::vector<unsigned int>& timesteps) { - // Recursive call: recurse back until stop condition, then append data in - // correct order as we walk back down the stack in the lines below.
- if (parent != nullptr) { - parent->get_path_vec(output, timesteps); +std::vector<unsigned int> PathTrie::get_path_vec() { + if (parent == nullptr) { + return std::vector<unsigned int>{}; } + std::vector<unsigned int> output_tokens=parent->get_path_vec(); if (character != ROOT_) { - output.push_back(character); - timesteps.push_back(timestep); + output_tokens.push_back(character); } + return output_tokens; } PathTrie* PathTrie::get_prev_grapheme(std::vector<unsigned int>& output, - std::vector<unsigned int>& timesteps, const Alphabet& alphabet) { PathTrie* stop = this; @@ -125,10 +120,9 @@ PathTrie* PathTrie::get_prev_grapheme(std::vector<unsigned int>& output, // Recursive call: recurse back until stop condition, then append data in // correct order as we walk back down the stack in the lines below. if (!byte_is_codepoint_boundary(alphabet.DecodeSingle(character)[0])) { - stop = parent->get_prev_grapheme(output, timesteps, alphabet); + stop = parent->get_prev_grapheme(output, alphabet); } output.push_back(character); - timesteps.push_back(timestep); return stop; } @@ -147,7 +141,6 @@ int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte, } PathTrie* PathTrie::get_prev_word(std::vector<unsigned int>& output, - std::vector<unsigned int>& timesteps, const Alphabet& alphabet) { PathTrie* stop = this; @@ -157,10 +150,9 @@ PathTrie* PathTrie::get_prev_word(std::vector<unsigned int>& output, // Recursive call: recurse back until stop condition, then append data in // correct order as we walk back down the stack in the lines below.
if (parent != nullptr) { - stop = parent->get_prev_word(output, timesteps, alphabet); + stop = parent->get_prev_word(output, alphabet); } output.push_back(character); - timesteps.push_back(timestep); return stop; } @@ -173,6 +165,10 @@ void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) { log_prob_nb_cur = -NUM_FLT_INF; score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev); + + timesteps = std::move(timesteps_cur); + timesteps_cur.clear(); + output.push_back(this); } for (auto child : children_) { @@ -229,8 +225,8 @@ void PathTrie::print(const Alphabet& a) { } } printf("\ntimesteps:\t "); - for (PathTrie* el : chain) { - printf("%d ", el->timestep); + for (unsigned int timestep : timesteps) { + printf("%d ", timestep); } printf("\n"); printf("transcript:\t %s\n", tr.c_str()); diff --git a/native_client/ctcdecode/path_trie.h b/native_client/ctcdecode/path_trie.h index 0a4374fc..0e832f15 100644 --- a/native_client/ctcdecode/path_trie.h +++ b/native_client/ctcdecode/path_trie.h @@ -21,14 +21,13 @@ public: ~PathTrie(); // get new prefix after appending new char - PathTrie* get_path_trie(unsigned int new_char, unsigned int new_timestep, float log_prob_c, bool reset = true); + PathTrie* get_path_trie(unsigned int new_char, float log_prob_c, bool reset = true); // get the prefix data in correct time order from root to current node - void get_path_vec(std::vector<unsigned int>& output, std::vector<unsigned int>& timesteps); + std::vector<unsigned int> get_path_vec(); // get the prefix data in correct time order from beginning of last grapheme to current node PathTrie* get_prev_grapheme(std::vector<unsigned int>& output, - std::vector<unsigned int>& timesteps, const Alphabet& alphabet); // get the distance from current node to the first codepoint boundary, and the byte value at the boundary @@ -36,7 +35,6 @@ public: // get the prefix data in correct time order from beginning of last word to current node PathTrie* get_prev_word(std::vector<unsigned int>& output, - std::vector<unsigned int>& timesteps, const Alphabet& alphabet); // update log probs @@ -65,7 +63,10 @@
public: float score; float approx_ctc; unsigned int character; - unsigned int timestep; + std::vector<unsigned int> timesteps; + // `timesteps_cur` is a temporary storage for each decoding step. + // At the end of a decoding step, it is moved to `timesteps`. + std::vector<unsigned int> timesteps_cur; PathTrie* parent; private: diff --git a/native_client/ctcdecode/scorer.cpp b/native_client/ctcdecode/scorer.cpp index ad41dd8e..c637cdd6 100644 --- a/native_client/ctcdecode/scorer.cpp +++ b/native_client/ctcdecode/scorer.cpp @@ -293,12 +293,12 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) } std::vector<unsigned int> prefix_vec; - std::vector<unsigned int> prefix_steps; + std::vector<unsigned int> prefix_steps = current_node->timesteps; if (is_utf8_mode_) { - new_node = current_node->get_prev_grapheme(prefix_vec, prefix_steps, alphabet_); + new_node = current_node->get_prev_grapheme(prefix_vec, alphabet_); } else { - new_node = current_node->get_prev_word(prefix_vec, prefix_steps, alphabet_); + new_node = current_node->get_prev_word(prefix_vec, alphabet_); } current_node = new_node->parent;