PR #3279 - use a tree structure to store timesteps
This commit is contained in:
parent
1f89bef5f0
commit
ec55597412
@ -37,6 +37,8 @@ DecoderState::init(const Alphabet& alphabet,
|
|||||||
prefix_root_.reset(root);
|
prefix_root_.reset(root);
|
||||||
prefixes_.push_back(root);
|
prefixes_.push_back(root);
|
||||||
|
|
||||||
|
timestep_tree_root_=std::make_shared<TimestepTreeNode>(nullptr, 0);
|
||||||
|
|
||||||
if (ext_scorer && (bool)(ext_scorer_->dictionary)) {
|
if (ext_scorer && (bool)(ext_scorer_->dictionary)) {
|
||||||
// no need for std::make_shared<>() since Copy() does 'new' behind the doors
|
// no need for std::make_shared<>() since Copy() does 'new' behind the doors
|
||||||
auto dict_ptr = std::shared_ptr<PathTrie::FstType>(ext_scorer->dictionary->Copy(true));
|
auto dict_ptr = std::shared_ptr<PathTrie::FstType>(ext_scorer->dictionary->Copy(true));
|
||||||
@ -96,7 +98,7 @@ DecoderState::next(const double *probs,
|
|||||||
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
|
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
assert(prefix->is_empty() || !prefix->timesteps.empty());
|
assert(prefix->is_empty() || prefix->timesteps!=nullptr);
|
||||||
|
|
||||||
// blank
|
// blank
|
||||||
if (c == blank_id_) {
|
if (c == blank_id_) {
|
||||||
@ -167,7 +169,7 @@ DecoderState::next(const double *probs,
|
|||||||
if (prefix_new->log_prob_nb_cur < log_p) {
|
if (prefix_new->log_prob_nb_cur < log_p) {
|
||||||
// record data needed to update timesteps
|
// record data needed to update timesteps
|
||||||
// the actual update will be done if nothing better is found
|
// the actual update will be done if nothing better is found
|
||||||
prefix_new->previous_timesteps = &prefix->timesteps;
|
prefix_new->previous_timesteps = prefix->timesteps;
|
||||||
prefix_new->new_timestep = abs_time_step_;
|
prefix_new->new_timestep = abs_time_step_;
|
||||||
}
|
}
|
||||||
prefix_new->log_prob_nb_cur =
|
prefix_new->log_prob_nb_cur =
|
||||||
@ -179,6 +181,11 @@ DecoderState::next(const double *probs,
|
|||||||
// update log probs
|
// update log probs
|
||||||
prefixes_.clear();
|
prefixes_.clear();
|
||||||
prefix_root_->iterate_to_vec(prefixes_);
|
prefix_root_->iterate_to_vec(prefixes_);
|
||||||
|
if (abs_time_step_ == 0) {
|
||||||
|
for (PathTrie* prefix:prefixes_) {
|
||||||
|
prefix->timesteps = timestep_tree_root_;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// only preserve top beam_size prefixes
|
// only preserve top beam_size prefixes
|
||||||
if (prefixes_.size() > beam_size_) {
|
if (prefixes_.size() > beam_size_) {
|
||||||
@ -234,7 +241,7 @@ DecoderState::decode(size_t num_results) const
|
|||||||
for (PathTrie* prefix : prefixes_copy) {
|
for (PathTrie* prefix : prefixes_copy) {
|
||||||
Output output;
|
Output output;
|
||||||
prefix->get_path_vec(output.tokens);
|
prefix->get_path_vec(output.tokens);
|
||||||
output.timesteps = prefix->timesteps;
|
output.timesteps = get_history(prefix->timesteps);
|
||||||
output.confidence = scores[prefix];
|
output.confidence = scores[prefix];
|
||||||
outputs.push_back(output);
|
outputs.push_back(output);
|
||||||
if(outputs.size()>=num_returned) break;
|
if(outputs.size()>=num_returned) break;
|
||||||
|
@ -21,6 +21,7 @@ class DecoderState {
|
|||||||
std::shared_ptr<Scorer> ext_scorer_;
|
std::shared_ptr<Scorer> ext_scorer_;
|
||||||
std::vector<PathTrie*> prefixes_;
|
std::vector<PathTrie*> prefixes_;
|
||||||
std::unique_ptr<PathTrie> prefix_root_;
|
std::unique_ptr<PathTrie> prefix_root_;
|
||||||
|
std::shared_ptr<TimestepTreeNode> timestep_tree_root_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
DecoderState() = default;
|
DecoderState() = default;
|
||||||
|
@ -172,8 +172,16 @@ void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
|
|||||||
score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
|
score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
|
||||||
|
|
||||||
if (previous_timesteps != nullptr) {
|
if (previous_timesteps != nullptr) {
|
||||||
timesteps = *previous_timesteps;
|
timesteps = nullptr;
|
||||||
timesteps.push_back(new_timestep);
|
for (auto const& child : previous_timesteps->children) {
|
||||||
|
if (child->data == new_timestep) {
|
||||||
|
timesteps=child;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (timesteps == nullptr){
|
||||||
|
timesteps = add_child(previous_timesteps, new_timestep);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
previous_timesteps=nullptr;
|
previous_timesteps=nullptr;
|
||||||
|
|
||||||
@ -230,7 +238,7 @@ void PathTrie::print(const Alphabet& a) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
printf("\ntimesteps:\t ");
|
printf("\ntimesteps:\t ");
|
||||||
for (unsigned int timestep : timesteps) {
|
for (unsigned int timestep : get_history(timesteps)) {
|
||||||
printf("%d ", timestep);
|
printf("%d ", timestep);
|
||||||
}
|
}
|
||||||
printf("\n");
|
printf("\n");
|
||||||
|
@ -10,6 +10,27 @@
|
|||||||
#include "fst/fstlib.h"
|
#include "fst/fstlib.h"
|
||||||
#include "alphabet.h"
|
#include "alphabet.h"
|
||||||
|
|
||||||
|
/* Generic tree node that records both its parent and its children.
 * TreeNode<unsigned int> (aliased as TimestepTreeNode) holds the timestep
 * history of each PathTrie prefix: a history is the chain of `data` values
 * read from the root down to a node.
 *
 * NOTE(review): `parent` and `children` are both owning shared_ptrs. Once a
 * child is attached through add_child() the parent and child reference each
 * other, so reference counting alone cannot release them. Confirm how the
 * tree is expected to be torn down.
 */
template<class DataT>
struct TreeNode{
    // Node one step closer to the root; the root is constructed with nullptr.
    std::shared_ptr<TreeNode> parent;
    // Nodes that extend this node's history by one element.
    std::vector<std::shared_ptr<TreeNode>> children;

    // Payload carried by this node (a timestep for TimestepTreeNode).
    DataT data;

    // Builds a childless node. The node is NOT registered in
    // parent_node->children here; add_child() does that.
    TreeNode(std::shared_ptr<TreeNode> const& parent_node, DataT const& payload)
        : parent{parent_node}
        , data{payload}
    {}
};
|
||||||
|
|
||||||
|
template<class NodeDataT, class ChildDataT>
|
||||||
|
std::shared_ptr<TreeNode<NodeDataT>> add_child(std::shared_ptr<TreeNode<NodeDataT>> const& node, ChildDataT&& data_);
|
||||||
|
|
||||||
|
template<class DataT>
|
||||||
|
std::vector<DataT> get_history(TreeNode<DataT>*);
|
||||||
|
|
||||||
|
using TimestepTreeNode=TreeNode<unsigned int>;
|
||||||
|
|
||||||
/* Trie tree for prefix storing and manipulating, with a dictionary in
|
/* Trie tree for prefix storing and manipulating, with a dictionary in
|
||||||
* finite-state transducer for spelling correction.
|
* finite-state transducer for spelling correction.
|
||||||
*/
|
*/
|
||||||
@ -63,10 +84,10 @@ public:
|
|||||||
float score;
|
float score;
|
||||||
float approx_ctc;
|
float approx_ctc;
|
||||||
unsigned int character;
|
unsigned int character;
|
||||||
std::vector<unsigned int> timesteps;
|
std::shared_ptr<TimestepTreeNode> timesteps;
|
||||||
|
|
||||||
// timestep temporary storage for each decoding step.
|
// timestep temporary storage for each decoding step.
|
||||||
std::vector<unsigned int>* previous_timesteps=nullptr;
|
std::shared_ptr<TimestepTreeNode> previous_timesteps=nullptr;
|
||||||
unsigned int new_timestep;
|
unsigned int new_timestep;
|
||||||
|
|
||||||
PathTrie* parent;
|
PathTrie* parent;
|
||||||
@ -84,4 +105,26 @@ private:
|
|||||||
std::shared_ptr<fst::SortedMatcher<FstType>> matcher_;
|
std::shared_ptr<fst::SortedMatcher<FstType>> matcher_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// TreeNode implementation
|
||||||
|
template<class NodeDataT, class ChildDataT>
|
||||||
|
std::shared_ptr<TreeNode<NodeDataT>> add_child(std::shared_ptr<TreeNode<NodeDataT>> const& node, ChildDataT&& data_){
|
||||||
|
node->children.push_back(std::make_shared<TreeNode<NodeDataT>>(node, std::forward<ChildDataT>(data_)));
|
||||||
|
return node->children.back();
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class DataT>
|
||||||
|
void get_history_helper(std::shared_ptr<TreeNode<DataT>> const& tree_node, std::vector<DataT>* output){
|
||||||
|
if(tree_node==nullptr) return;
|
||||||
|
assert(tree_node->parent != tree_node);
|
||||||
|
get_history_helper(tree_node->parent, output);
|
||||||
|
output->push_back(tree_node->data);
|
||||||
|
}
|
||||||
|
template<class DataT>
|
||||||
|
std::vector<DataT> get_history(std::shared_ptr<TreeNode<DataT>> const& tree_node){
|
||||||
|
std::vector<DataT> output;
|
||||||
|
get_history_helper(tree_node, &output);
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
#endif // PATH_TRIE_H
|
#endif // PATH_TRIE_H
|
||||||
|
@ -293,7 +293,6 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<unsigned int> prefix_vec;
|
std::vector<unsigned int> prefix_vec;
|
||||||
std::vector<unsigned int> prefix_steps = current_node->timesteps;
|
|
||||||
|
|
||||||
if (is_utf8_mode_) {
|
if (is_utf8_mode_) {
|
||||||
new_node = current_node->get_prev_grapheme(prefix_vec, alphabet_);
|
new_node = current_node->get_prev_grapheme(prefix_vec, alphabet_);
|
||||||
|
Loading…
Reference in New Issue
Block a user