PR #3279 - use a tree structure to store timesteps
This commit is contained in:
parent
1f89bef5f0
commit
ec55597412
@ -37,6 +37,8 @@ DecoderState::init(const Alphabet& alphabet,
|
||||
prefix_root_.reset(root);
|
||||
prefixes_.push_back(root);
|
||||
|
||||
timestep_tree_root_=std::make_shared<TimestepTreeNode>(nullptr, 0);
|
||||
|
||||
if (ext_scorer && (bool)(ext_scorer_->dictionary)) {
|
||||
// no need for std::make_shared<>() since Copy() does 'new' behind the doors
|
||||
auto dict_ptr = std::shared_ptr<PathTrie::FstType>(ext_scorer->dictionary->Copy(true));
|
||||
@ -96,7 +98,7 @@ DecoderState::next(const double *probs,
|
||||
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
|
||||
break;
|
||||
}
|
||||
assert(prefix->is_empty() || !prefix->timesteps.empty());
|
||||
assert(prefix->is_empty() || prefix->timesteps!=nullptr);
|
||||
|
||||
// blank
|
||||
if (c == blank_id_) {
|
||||
@ -167,7 +169,7 @@ DecoderState::next(const double *probs,
|
||||
if (prefix_new->log_prob_nb_cur < log_p) {
|
||||
// record data needed to update timesteps
|
||||
// the actual update will be done if nothing better is found
|
||||
prefix_new->previous_timesteps = &prefix->timesteps;
|
||||
prefix_new->previous_timesteps = prefix->timesteps;
|
||||
prefix_new->new_timestep = abs_time_step_;
|
||||
}
|
||||
prefix_new->log_prob_nb_cur =
|
||||
@ -179,6 +181,11 @@ DecoderState::next(const double *probs,
|
||||
// update log probs
|
||||
prefixes_.clear();
|
||||
prefix_root_->iterate_to_vec(prefixes_);
|
||||
if (abs_time_step_ == 0) {
|
||||
for (PathTrie* prefix:prefixes_) {
|
||||
prefix->timesteps = timestep_tree_root_;
|
||||
}
|
||||
}
|
||||
|
||||
// only preserve top beam_size prefixes
|
||||
if (prefixes_.size() > beam_size_) {
|
||||
@ -234,7 +241,7 @@ DecoderState::decode(size_t num_results) const
|
||||
for (PathTrie* prefix : prefixes_copy) {
|
||||
Output output;
|
||||
prefix->get_path_vec(output.tokens);
|
||||
output.timesteps = prefix->timesteps;
|
||||
output.timesteps = get_history(prefix->timesteps);
|
||||
output.confidence = scores[prefix];
|
||||
outputs.push_back(output);
|
||||
if(outputs.size()>=num_returned) break;
|
||||
|
@ -21,6 +21,7 @@ class DecoderState {
|
||||
std::shared_ptr<Scorer> ext_scorer_;
|
||||
std::vector<PathTrie*> prefixes_;
|
||||
std::unique_ptr<PathTrie> prefix_root_;
|
||||
std::shared_ptr<TimestepTreeNode> timestep_tree_root_;
|
||||
|
||||
public:
|
||||
DecoderState() = default;
|
||||
|
@ -172,8 +172,16 @@ void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
|
||||
score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
|
||||
|
||||
if (previous_timesteps != nullptr) {
|
||||
timesteps = *previous_timesteps;
|
||||
timesteps.push_back(new_timestep);
|
||||
timesteps = nullptr;
|
||||
for (auto const& child : previous_timesteps->children) {
|
||||
if (child->data == new_timestep) {
|
||||
timesteps=child;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (timesteps == nullptr){
|
||||
timesteps = add_child(previous_timesteps, new_timestep);
|
||||
}
|
||||
}
|
||||
previous_timesteps=nullptr;
|
||||
|
||||
@ -230,7 +238,7 @@ void PathTrie::print(const Alphabet& a) {
|
||||
}
|
||||
}
|
||||
printf("\ntimesteps:\t ");
|
||||
for (unsigned int timestep : timesteps) {
|
||||
for (unsigned int timestep : get_history(timesteps)) {
|
||||
printf("%d ", timestep);
|
||||
}
|
||||
printf("\n");
|
||||
|
@ -10,6 +10,27 @@
|
||||
#include "fst/fstlib.h"
|
||||
#include "alphabet.h"
|
||||
|
||||
/* Tree structure with parent and children information
|
||||
* It is used to store the timesteps data for the PathTrie below
|
||||
*/
|
||||
template<class DataT>
|
||||
struct TreeNode{
|
||||
std::shared_ptr<TreeNode<DataT>> parent;
|
||||
std::vector<std::shared_ptr<TreeNode<DataT>>> children;
|
||||
|
||||
DataT data;
|
||||
|
||||
TreeNode(std::shared_ptr<TreeNode<DataT>> const& parent_, DataT const& data_): parent{parent_}, data{data_} {}
|
||||
};
|
||||
|
||||
template<class NodeDataT, class ChildDataT>
|
||||
std::shared_ptr<TreeNode<NodeDataT>> add_child(std::shared_ptr<TreeNode<NodeDataT>> const& node, ChildDataT&& data_);
|
||||
|
||||
template<class DataT>
|
||||
std::vector<DataT> get_history(TreeNode<DataT>*);
|
||||
|
||||
using TimestepTreeNode=TreeNode<unsigned int>;
|
||||
|
||||
/* Trie tree for prefix storing and manipulating, with a dictionary in
|
||||
* finite-state transducer for spelling correction.
|
||||
*/
|
||||
@ -63,10 +84,10 @@ public:
|
||||
float score;
|
||||
float approx_ctc;
|
||||
unsigned int character;
|
||||
std::vector<unsigned int> timesteps;
|
||||
std::shared_ptr<TimestepTreeNode> timesteps;
|
||||
|
||||
// timestep temporary storage for each decoding step.
|
||||
std::vector<unsigned int>* previous_timesteps=nullptr;
|
||||
std::shared_ptr<TimestepTreeNode> previous_timesteps=nullptr;
|
||||
unsigned int new_timestep;
|
||||
|
||||
PathTrie* parent;
|
||||
@ -84,4 +105,26 @@ private:
|
||||
std::shared_ptr<fst::SortedMatcher<FstType>> matcher_;
|
||||
};
|
||||
|
||||
// TreeNode implementation
|
||||
template<class NodeDataT, class ChildDataT>
|
||||
std::shared_ptr<TreeNode<NodeDataT>> add_child(std::shared_ptr<TreeNode<NodeDataT>> const& node, ChildDataT&& data_){
|
||||
node->children.push_back(std::make_shared<TreeNode<NodeDataT>>(node, std::forward<ChildDataT>(data_)));
|
||||
return node->children.back();
|
||||
}
|
||||
|
||||
template<class DataT>
|
||||
void get_history_helper(std::shared_ptr<TreeNode<DataT>> const& tree_node, std::vector<DataT>* output){
|
||||
if(tree_node==nullptr) return;
|
||||
assert(tree_node->parent != tree_node);
|
||||
get_history_helper(tree_node->parent, output);
|
||||
output->push_back(tree_node->data);
|
||||
}
|
||||
template<class DataT>
|
||||
std::vector<DataT> get_history(std::shared_ptr<TreeNode<DataT>> const& tree_node){
|
||||
std::vector<DataT> output;
|
||||
get_history_helper(tree_node, &output);
|
||||
return output;
|
||||
}
|
||||
|
||||
|
||||
#endif // PATH_TRIE_H
|
||||
|
@ -293,7 +293,6 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
|
||||
}
|
||||
|
||||
std::vector<unsigned int> prefix_vec;
|
||||
std::vector<unsigned int> prefix_steps = current_node->timesteps;
|
||||
|
||||
if (is_utf8_mode_) {
|
||||
new_node = current_node->get_prev_grapheme(prefix_vec, alphabet_);
|
||||
|
Loading…
Reference in New Issue
Block a user