From ec55597412da8b43d079b650811321b82d60c588 Mon Sep 17 00:00:00 2001 From: godeffroy Date: Mon, 7 Sep 2020 13:37:27 +0200 Subject: [PATCH] PR #3279 - use a tree structure to store timesteps --- .../ctcdecode/ctc_beam_search_decoder.cpp | 13 +++-- .../ctcdecode/ctc_beam_search_decoder.h | 1 + native_client/ctcdecode/path_trie.cpp | 14 ++++-- native_client/ctcdecode/path_trie.h | 47 ++++++++++++++++++- native_client/ctcdecode/scorer.cpp | 1 - 5 files changed, 67 insertions(+), 9 deletions(-) diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.cpp b/native_client/ctcdecode/ctc_beam_search_decoder.cpp index 0fe7784c..f317b0ec 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.cpp +++ b/native_client/ctcdecode/ctc_beam_search_decoder.cpp @@ -36,6 +36,8 @@ DecoderState::init(const Alphabet& alphabet, root->score = root->log_prob_b_prev = 0.0; prefix_root_.reset(root); prefixes_.push_back(root); + + timestep_tree_root_=std::make_shared(nullptr, 0); if (ext_scorer && (bool)(ext_scorer_->dictionary)) { // no need for std::make_shared<>() since Copy() does 'new' behind the doors @@ -96,7 +98,7 @@ DecoderState::next(const double *probs, if (full_beam && log_prob_c + prefix->score < min_cutoff) { break; } - assert(prefix->is_empty() || !prefix->timesteps.empty()); + assert(prefix->is_empty() || prefix->timesteps!=nullptr); // blank if (c == blank_id_) { @@ -167,7 +169,7 @@ DecoderState::next(const double *probs, if (prefix_new->log_prob_nb_cur < log_p) { // record data needed to update timesteps // the actual update will be done if nothing better is found - prefix_new->previous_timesteps = &prefix->timesteps; + prefix_new->previous_timesteps = prefix->timesteps; prefix_new->new_timestep = abs_time_step_; } prefix_new->log_prob_nb_cur = @@ -179,6 +181,11 @@ DecoderState::next(const double *probs, // update log probs prefixes_.clear(); prefix_root_->iterate_to_vec(prefixes_); + if (abs_time_step_ == 0) { + for (PathTrie* prefix:prefixes_) { + prefix->timesteps = timestep_tree_root_; + } + } // only preserve top beam_size prefixes if (prefixes_.size() > beam_size_) { @@ -234,7 +241,7 @@ DecoderState::decode(size_t num_results) const for (PathTrie* prefix : prefixes_copy) { Output output; prefix->get_path_vec(output.tokens); - output.timesteps = prefix->timesteps; + output.timesteps = get_history(prefix->timesteps); output.confidence = scores[prefix]; outputs.push_back(output); if(outputs.size()>=num_returned) break; diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.h b/native_client/ctcdecode/ctc_beam_search_decoder.h index 73bdfcc7..e8324f5e 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.h +++ b/native_client/ctcdecode/ctc_beam_search_decoder.h @@ -21,6 +21,7 @@ class DecoderState { std::shared_ptr ext_scorer_; std::vector prefixes_; std::unique_ptr prefix_root_; + std::shared_ptr timestep_tree_root_; public: DecoderState() = default; diff --git a/native_client/ctcdecode/path_trie.cpp b/native_client/ctcdecode/path_trie.cpp index cca89a93..3f9ab91f 100644 --- a/native_client/ctcdecode/path_trie.cpp +++ b/native_client/ctcdecode/path_trie.cpp @@ -172,8 +172,16 @@ void PathTrie::iterate_to_vec(std::vector& output) { score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev); if (previous_timesteps != nullptr) { - timesteps = *previous_timesteps; - timesteps.push_back(new_timestep); + timesteps = nullptr; + for (auto const& child : previous_timesteps->children) { + if (child->data == new_timestep) { + timesteps=child; + break; + } + } + if (timesteps == nullptr){ + timesteps = add_child(previous_timesteps, new_timestep); + } } previous_timesteps=nullptr; @@ -230,7 +238,7 @@ void PathTrie::print(const Alphabet& a) { } } printf("\ntimesteps:\t "); - for (unsigned int timestep : timesteps) { + for (unsigned int timestep : get_history(timesteps)) { printf("%d ", timestep); } printf("\n"); diff --git a/native_client/ctcdecode/path_trie.h b/native_client/ctcdecode/path_trie.h index 00e0c3bf..8c3ae0a7 100644 --- a/native_client/ctcdecode/path_trie.h +++ b/native_client/ctcdecode/path_trie.h @@ -10,6 +10,27 @@ #include "fst/fstlib.h" #include "alphabet.h" +/* Tree structure with parent and children information + * It is used to store the timesteps data for the PathTrie below + */ +template +struct TreeNode{ + std::shared_ptr> parent; + std::vector>> children; + + DataT data; + + TreeNode(std::shared_ptr> const& parent_, DataT const& data_): parent{parent_}, data{data_} {} +}; + +template +std::shared_ptr> add_child(std::shared_ptr> const& node, ChildDataT&& data_); + +template +std::vector get_history(TreeNode*); + +using TimestepTreeNode=TreeNode; + /* Trie tree for prefix storing and manipulating, with a dictionary in * finite-state transducer for spelling correction. */ @@ -63,10 +84,10 @@ public: float score; float approx_ctc; unsigned int character; - std::vector timesteps; + std::shared_ptr timesteps; // timestep temporary storage for each decoding step. - std::vector* previous_timesteps=nullptr; + std::shared_ptr previous_timesteps=nullptr; unsigned int new_timestep; PathTrie* parent; @@ -84,4 +105,26 @@ private: std::shared_ptr> matcher_; }; +// TreeNode implementation +template +std::shared_ptr> add_child(std::shared_ptr> const& node, ChildDataT&& data_){ + node->children.push_back(std::make_shared>(node, std::forward(data_))); + return node->children.back(); +} + +template +void get_history_helper(std::shared_ptr> const& tree_node, std::vector* output){ + if(tree_node==nullptr) return; + assert(tree_node->parent != tree_node); + get_history_helper(tree_node->parent, output); + output->push_back(tree_node->data); +} +template +std::vector get_history(std::shared_ptr> const& tree_node){ + std::vector output; + get_history_helper(tree_node, &output); + return output; +} + + #endif // PATH_TRIE_H diff --git a/native_client/ctcdecode/scorer.cpp b/native_client/ctcdecode/scorer.cpp index c637cdd6..7c86f9d5 100644 --- a/native_client/ctcdecode/scorer.cpp +++ b/native_client/ctcdecode/scorer.cpp @@ -293,7 +293,6 @@ std::vector Scorer::make_ngram(PathTrie* prefix) } std::vector prefix_vec; - std::vector prefix_steps = current_node->timesteps; if (is_utf8_mode_) { new_node = current_node->get_prev_grapheme(prefix_vec, alphabet_);