PR #3279 - Fixed buggy timestep tree root

This commit is contained in:
godeffroy 2020-09-15 21:30:06 +02:00
parent 14bd9033d6
commit 1fa2e4ebcc
2 changed files with 11 additions and 13 deletions

View File

@ -35,6 +35,7 @@ DecoderState::init(const Alphabet& alphabet,
PathTrie *root = new PathTrie;
root->score = root->log_prob_b_prev = 0.0;
prefix_root_.reset(root);
prefix_root_->timesteps = &timestep_tree_root_;
prefixes_.push_back(root);
if (ext_scorer && (bool)(ext_scorer_->dictionary)) {
@ -99,7 +100,7 @@ DecoderState::next(const double *probs,
if (prefix->score == -NUM_FLT_INF) {
continue;
}
assert(prefix->is_empty() || prefix->timesteps != nullptr);
assert(prefix->timesteps != nullptr);
// blank
if (c == blank_id_) {
@ -182,11 +183,6 @@ DecoderState::next(const double *probs,
// update log probs
prefixes_.clear();
prefix_root_->iterate_to_vec(prefixes_);
if (abs_time_step_ == 0) {
for (PathTrie* prefix:prefixes_) {
prefix->timesteps = &timestep_tree_root_;
}
}
// only preserve top beam_size prefixes
if (prefixes_.size() > beam_size_) {
@ -242,7 +238,8 @@ DecoderState::decode(size_t num_results) const
for (PathTrie* prefix : prefixes_copy) {
Output output;
prefix->get_path_vec(output.tokens);
output.timesteps = get_history(prefix->timesteps);
output.timesteps = get_history(prefix->timesteps, &timestep_tree_root_);
assert(output.tokens.size() == output.timesteps.size());
output.confidence = scores[prefix];
outputs.push_back(output);
if (outputs.size() >= num_returned) break;

View File

@ -28,7 +28,7 @@ template<class NodeDataT, class ChildDataT>
TreeNode<NodeDataT>* add_child(TreeNode<NodeDataT>* node, ChildDataT&& data_);
template<class DataT>
std::vector<DataT> get_history(TreeNode<DataT>*);
std::vector<DataT> get_history(TreeNode<DataT> const*, TreeNode<DataT> const* = nullptr);
using TimestepTreeNode = TreeNode<unsigned int>;
@ -115,16 +115,17 @@ TreeNode<NodeDataT>* add_child(TreeNode<NodeDataT>* node, ChildDataT&& data_) {
}
template<class DataT>
void get_history_helper(TreeNode<DataT>* tree_node, std::vector<DataT>* output) {
if (tree_node == nullptr) return;
void get_history_helper(TreeNode<DataT> const* tree_node, TreeNode<DataT> const* root, std::vector<DataT>* output) {
if (tree_node == root) return;
assert(tree_node != nullptr);
assert(tree_node->parent != tree_node);
get_history_helper(tree_node->parent, output);
get_history_helper(tree_node->parent, root, output);
output->push_back(tree_node->data);
}
template<class DataT>
std::vector<DataT> get_history(TreeNode<DataT>* tree_node) {
std::vector<DataT> get_history(TreeNode<DataT> const* tree_node, TreeNode<DataT> const* root) {
std::vector<DataT> output;
get_history_helper(tree_node, &output);
get_history_helper(tree_node, root, &output);
return output;
}