From 1fa2e4ebccd6102b5e4ff837147854a2c0369f72 Mon Sep 17 00:00:00 2001 From: godeffroy Date: Tue, 15 Sep 2020 21:30:06 +0200 Subject: [PATCH] PR #3279 - Fixed buggy timestep tree root --- native_client/ctcdecode/ctc_beam_search_decoder.cpp | 11 ++++------- native_client/ctcdecode/path_trie.h | 13 +++++++------ 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.cpp b/native_client/ctcdecode/ctc_beam_search_decoder.cpp index b533d3dc..874dc7c1 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.cpp +++ b/native_client/ctcdecode/ctc_beam_search_decoder.cpp @@ -35,6 +35,7 @@ DecoderState::init(const Alphabet& alphabet, PathTrie *root = new PathTrie; root->score = root->log_prob_b_prev = 0.0; prefix_root_.reset(root); + prefix_root_->timesteps = &timestep_tree_root_; prefixes_.push_back(root); if (ext_scorer && (bool)(ext_scorer_->dictionary)) { @@ -99,7 +100,7 @@ DecoderState::next(const double *probs, if (prefix->score == -NUM_FLT_INF) { continue; } - assert(prefix->is_empty() || prefix->timesteps != nullptr); + assert(prefix->timesteps != nullptr); // blank if (c == blank_id_) { @@ -182,11 +183,6 @@ DecoderState::next(const double *probs, // update log probs prefixes_.clear(); prefix_root_->iterate_to_vec(prefixes_); - if (abs_time_step_ == 0) { - for (PathTrie* prefix:prefixes_) { - prefix->timesteps = &timestep_tree_root_; - } - } // only preserve top beam_size prefixes if (prefixes_.size() > beam_size_) { @@ -242,7 +238,8 @@ DecoderState::decode(size_t num_results) const for (PathTrie* prefix : prefixes_copy) { Output output; prefix->get_path_vec(output.tokens); - output.timesteps = get_history(prefix->timesteps); + output.timesteps = get_history(prefix->timesteps, &timestep_tree_root_); + assert(output.tokens.size() == output.timesteps.size()); output.confidence = scores[prefix]; outputs.push_back(output); if (outputs.size() >= num_returned) break; diff --git a/native_client/ctcdecode/path_trie.h 
b/native_client/ctcdecode/path_trie.h index 92dd288b..eac0d7b4 100644 --- a/native_client/ctcdecode/path_trie.h +++ b/native_client/ctcdecode/path_trie.h @@ -28,7 +28,7 @@ template TreeNode* add_child(TreeNode* node, ChildDataT&& data_); template -std::vector get_history(TreeNode*); +std::vector get_history(TreeNode const*, TreeNode const* = nullptr); using TimestepTreeNode = TreeNode; @@ -115,16 +115,17 @@ TreeNode* add_child(TreeNode* node, ChildDataT&& data_) { } template -void get_history_helper(TreeNode* tree_node, std::vector* output) { - if (tree_node == nullptr) return; +void get_history_helper(TreeNode const* tree_node, TreeNode const* root, std::vector* output) { + if (tree_node == root) return; + assert(tree_node != nullptr); assert(tree_node->parent != tree_node); - get_history_helper(tree_node->parent, output); + get_history_helper(tree_node->parent, root, output); output->push_back(tree_node->data); } template -std::vector get_history(TreeNode* tree_node) { +std::vector get_history(TreeNode const* tree_node, TreeNode const* root) { std::vector output; - get_history_helper(tree_node, &output); + get_history_helper(tree_node, root, &output); return output; }