PR #3279 - use a tree structure to store timesteps

This commit is contained in:
godeffroy 2020-09-07 13:37:27 +02:00
parent 1f89bef5f0
commit ec55597412
5 changed files with 67 additions and 9 deletions

View File

@ -36,6 +36,8 @@ DecoderState::init(const Alphabet& alphabet,
root->score = root->log_prob_b_prev = 0.0;
prefix_root_.reset(root);
prefixes_.push_back(root);
timestep_tree_root_=std::make_shared<TimestepTreeNode>(nullptr, 0);
if (ext_scorer && (bool)(ext_scorer_->dictionary)) {
// no need for std::make_shared<>() since Copy() does 'new' behind the doors
@ -96,7 +98,7 @@ DecoderState::next(const double *probs,
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
assert(prefix->is_empty() || !prefix->timesteps.empty());
assert(prefix->is_empty() || prefix->timesteps!=nullptr);
// blank
if (c == blank_id_) {
@ -167,7 +169,7 @@ DecoderState::next(const double *probs,
if (prefix_new->log_prob_nb_cur < log_p) {
// record data needed to update timesteps
// the actual update will be done if nothing better is found
prefix_new->previous_timesteps = &prefix->timesteps;
prefix_new->previous_timesteps = prefix->timesteps;
prefix_new->new_timestep = abs_time_step_;
}
prefix_new->log_prob_nb_cur =
@ -179,6 +181,11 @@ DecoderState::next(const double *probs,
// update log probs
prefixes_.clear();
prefix_root_->iterate_to_vec(prefixes_);
if (abs_time_step_ == 0) {
for (PathTrie* prefix:prefixes_) {
prefix->timesteps = timestep_tree_root_;
}
}
// only preserve top beam_size prefixes
if (prefixes_.size() > beam_size_) {
@ -234,7 +241,7 @@ DecoderState::decode(size_t num_results) const
for (PathTrie* prefix : prefixes_copy) {
Output output;
prefix->get_path_vec(output.tokens);
output.timesteps = prefix->timesteps;
output.timesteps = get_history(prefix->timesteps);
output.confidence = scores[prefix];
outputs.push_back(output);
if(outputs.size()>=num_returned) break;

View File

@ -21,6 +21,7 @@ class DecoderState {
std::shared_ptr<Scorer> ext_scorer_;
std::vector<PathTrie*> prefixes_;
std::unique_ptr<PathTrie> prefix_root_;
std::shared_ptr<TimestepTreeNode> timestep_tree_root_;
public:
DecoderState() = default;

View File

@ -172,8 +172,16 @@ void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
if (previous_timesteps != nullptr) {
timesteps = *previous_timesteps;
timesteps.push_back(new_timestep);
timesteps = nullptr;
for (auto const& child : previous_timesteps->children) {
if (child->data == new_timestep) {
timesteps=child;
break;
}
}
if (timesteps == nullptr){
timesteps = add_child(previous_timesteps, new_timestep);
}
}
previous_timesteps=nullptr;
@ -230,7 +238,7 @@ void PathTrie::print(const Alphabet& a) {
}
}
printf("\ntimesteps:\t ");
for (unsigned int timestep : timesteps) {
for (unsigned int timestep : get_history(timesteps)) {
printf("%d ", timestep);
}
printf("\n");

View File

@ -10,6 +10,27 @@
#include "fst/fstlib.h"
#include "alphabet.h"
/* Tree structure with parent and children information
* It is used to store the timesteps data for the PathTrie below
*/
template<class DataT>
struct TreeNode{
std::shared_ptr<TreeNode<DataT>> parent;
std::vector<std::shared_ptr<TreeNode<DataT>>> children;
DataT data;
TreeNode(std::shared_ptr<TreeNode<DataT>> const& parent_, DataT const& data_): parent{parent_}, data{data_} {}
};
template<class NodeDataT, class ChildDataT>
std::shared_ptr<TreeNode<NodeDataT>> add_child(std::shared_ptr<TreeNode<NodeDataT>> const& node, ChildDataT&& data_);
template<class DataT>
std::vector<DataT> get_history(TreeNode<DataT>*);
using TimestepTreeNode=TreeNode<unsigned int>;
/* Trie tree for prefix storing and manipulating, with a dictionary in
* finite-state transducer for spelling correction.
*/
@ -63,10 +84,10 @@ public:
float score;
float approx_ctc;
unsigned int character;
std::vector<unsigned int> timesteps;
std::shared_ptr<TimestepTreeNode> timesteps;
// timestep temporary storage for each decoding step.
std::vector<unsigned int>* previous_timesteps=nullptr;
std::shared_ptr<TimestepTreeNode> previous_timesteps=nullptr;
unsigned int new_timestep;
PathTrie* parent;
@ -84,4 +105,26 @@ private:
std::shared_ptr<fst::SortedMatcher<FstType>> matcher_;
};
// TreeNode implementation
template<class NodeDataT, class ChildDataT>
std::shared_ptr<TreeNode<NodeDataT>> add_child(std::shared_ptr<TreeNode<NodeDataT>> const& node, ChildDataT&& data_){
node->children.push_back(std::make_shared<TreeNode<NodeDataT>>(node, std::forward<ChildDataT>(data_)));
return node->children.back();
}
template<class DataT>
void get_history_helper(std::shared_ptr<TreeNode<DataT>> const& tree_node, std::vector<DataT>* output){
if(tree_node==nullptr) return;
assert(tree_node->parent != tree_node);
get_history_helper(tree_node->parent, output);
output->push_back(tree_node->data);
}
template<class DataT>
std::vector<DataT> get_history(std::shared_ptr<TreeNode<DataT>> const& tree_node){
std::vector<DataT> output;
get_history_helper(tree_node, &output);
return output;
}
#endif // PATH_TRIE_H

View File

@ -293,7 +293,6 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
}
std::vector<unsigned int> prefix_vec;
std::vector<unsigned int> prefix_steps = current_node->timesteps;
if (is_utf8_mode_) {
new_node = current_node->get_prev_grapheme(prefix_vec, alphabet_);