PR #3279 - Fixed buggy timestep tree root
This commit is contained in:
parent
14bd9033d6
commit
1fa2e4ebcc
@ -35,6 +35,7 @@ DecoderState::init(const Alphabet& alphabet,
|
||||
PathTrie *root = new PathTrie;
|
||||
root->score = root->log_prob_b_prev = 0.0;
|
||||
prefix_root_.reset(root);
|
||||
prefix_root_->timesteps = ×tep_tree_root_;
|
||||
prefixes_.push_back(root);
|
||||
|
||||
if (ext_scorer && (bool)(ext_scorer_->dictionary)) {
|
||||
@ -99,7 +100,7 @@ DecoderState::next(const double *probs,
|
||||
if (prefix->score == -NUM_FLT_INF) {
|
||||
continue;
|
||||
}
|
||||
assert(prefix->is_empty() || prefix->timesteps != nullptr);
|
||||
assert(prefix->timesteps != nullptr);
|
||||
|
||||
// blank
|
||||
if (c == blank_id_) {
|
||||
@ -182,11 +183,6 @@ DecoderState::next(const double *probs,
|
||||
// update log probs
|
||||
prefixes_.clear();
|
||||
prefix_root_->iterate_to_vec(prefixes_);
|
||||
if (abs_time_step_ == 0) {
|
||||
for (PathTrie* prefix:prefixes_) {
|
||||
prefix->timesteps = ×tep_tree_root_;
|
||||
}
|
||||
}
|
||||
|
||||
// only preserve top beam_size prefixes
|
||||
if (prefixes_.size() > beam_size_) {
|
||||
@ -242,7 +238,8 @@ DecoderState::decode(size_t num_results) const
|
||||
for (PathTrie* prefix : prefixes_copy) {
|
||||
Output output;
|
||||
prefix->get_path_vec(output.tokens);
|
||||
output.timesteps = get_history(prefix->timesteps);
|
||||
output.timesteps = get_history(prefix->timesteps, ×tep_tree_root_);
|
||||
assert(output.tokens.size() == output.timesteps.size());
|
||||
output.confidence = scores[prefix];
|
||||
outputs.push_back(output);
|
||||
if (outputs.size() >= num_returned) break;
|
||||
|
@ -28,7 +28,7 @@ template<class NodeDataT, class ChildDataT>
|
||||
TreeNode<NodeDataT>* add_child(TreeNode<NodeDataT>* node, ChildDataT&& data_);
|
||||
|
||||
template<class DataT>
|
||||
std::vector<DataT> get_history(TreeNode<DataT>*);
|
||||
std::vector<DataT> get_history(TreeNode<DataT> const*, TreeNode<DataT> const* = nullptr);
|
||||
|
||||
using TimestepTreeNode = TreeNode<unsigned int>;
|
||||
|
||||
@ -115,16 +115,17 @@ TreeNode<NodeDataT>* add_child(TreeNode<NodeDataT>* node, ChildDataT&& data_) {
|
||||
}
|
||||
|
||||
template<class DataT>
|
||||
void get_history_helper(TreeNode<DataT>* tree_node, std::vector<DataT>* output) {
|
||||
if (tree_node == nullptr) return;
|
||||
void get_history_helper(TreeNode<DataT> const* tree_node, TreeNode<DataT> const* root, std::vector<DataT>* output) {
|
||||
if (tree_node == root) return;
|
||||
assert(tree_node != nullptr);
|
||||
assert(tree_node->parent != tree_node);
|
||||
get_history_helper(tree_node->parent, output);
|
||||
get_history_helper(tree_node->parent, root, output);
|
||||
output->push_back(tree_node->data);
|
||||
}
|
||||
template<class DataT>
|
||||
std::vector<DataT> get_history(TreeNode<DataT>* tree_node) {
|
||||
std::vector<DataT> get_history(TreeNode<DataT> const* tree_node, TreeNode<DataT> const* root) {
|
||||
std::vector<DataT> output;
|
||||
get_history_helper(tree_node, &output);
|
||||
get_history_helper(tree_node, root, &output);
|
||||
return output;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user