PR #3279 - Fixed buggy timestep tree root
This commit is contained in:
parent
14bd9033d6
commit
1fa2e4ebcc
@ -35,6 +35,7 @@ DecoderState::init(const Alphabet& alphabet,
|
|||||||
PathTrie *root = new PathTrie;
|
PathTrie *root = new PathTrie;
|
||||||
root->score = root->log_prob_b_prev = 0.0;
|
root->score = root->log_prob_b_prev = 0.0;
|
||||||
prefix_root_.reset(root);
|
prefix_root_.reset(root);
|
||||||
|
prefix_root_->timesteps = ×tep_tree_root_;
|
||||||
prefixes_.push_back(root);
|
prefixes_.push_back(root);
|
||||||
|
|
||||||
if (ext_scorer && (bool)(ext_scorer_->dictionary)) {
|
if (ext_scorer && (bool)(ext_scorer_->dictionary)) {
|
||||||
@ -99,7 +100,7 @@ DecoderState::next(const double *probs,
|
|||||||
if (prefix->score == -NUM_FLT_INF) {
|
if (prefix->score == -NUM_FLT_INF) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
assert(prefix->is_empty() || prefix->timesteps != nullptr);
|
assert(prefix->timesteps != nullptr);
|
||||||
|
|
||||||
// blank
|
// blank
|
||||||
if (c == blank_id_) {
|
if (c == blank_id_) {
|
||||||
@ -182,11 +183,6 @@ DecoderState::next(const double *probs,
|
|||||||
// update log probs
|
// update log probs
|
||||||
prefixes_.clear();
|
prefixes_.clear();
|
||||||
prefix_root_->iterate_to_vec(prefixes_);
|
prefix_root_->iterate_to_vec(prefixes_);
|
||||||
if (abs_time_step_ == 0) {
|
|
||||||
for (PathTrie* prefix:prefixes_) {
|
|
||||||
prefix->timesteps = ×tep_tree_root_;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// only preserve top beam_size prefixes
|
// only preserve top beam_size prefixes
|
||||||
if (prefixes_.size() > beam_size_) {
|
if (prefixes_.size() > beam_size_) {
|
||||||
@ -242,7 +238,8 @@ DecoderState::decode(size_t num_results) const
|
|||||||
for (PathTrie* prefix : prefixes_copy) {
|
for (PathTrie* prefix : prefixes_copy) {
|
||||||
Output output;
|
Output output;
|
||||||
prefix->get_path_vec(output.tokens);
|
prefix->get_path_vec(output.tokens);
|
||||||
output.timesteps = get_history(prefix->timesteps);
|
output.timesteps = get_history(prefix->timesteps, ×tep_tree_root_);
|
||||||
|
assert(output.tokens.size() == output.timesteps.size());
|
||||||
output.confidence = scores[prefix];
|
output.confidence = scores[prefix];
|
||||||
outputs.push_back(output);
|
outputs.push_back(output);
|
||||||
if (outputs.size() >= num_returned) break;
|
if (outputs.size() >= num_returned) break;
|
||||||
|
@ -28,7 +28,7 @@ template<class NodeDataT, class ChildDataT>
|
|||||||
TreeNode<NodeDataT>* add_child(TreeNode<NodeDataT>* node, ChildDataT&& data_);
|
TreeNode<NodeDataT>* add_child(TreeNode<NodeDataT>* node, ChildDataT&& data_);
|
||||||
|
|
||||||
template<class DataT>
|
template<class DataT>
|
||||||
std::vector<DataT> get_history(TreeNode<DataT>*);
|
std::vector<DataT> get_history(TreeNode<DataT> const*, TreeNode<DataT> const* = nullptr);
|
||||||
|
|
||||||
using TimestepTreeNode = TreeNode<unsigned int>;
|
using TimestepTreeNode = TreeNode<unsigned int>;
|
||||||
|
|
||||||
@ -115,16 +115,17 @@ TreeNode<NodeDataT>* add_child(TreeNode<NodeDataT>* node, ChildDataT&& data_) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template<class DataT>
|
template<class DataT>
|
||||||
void get_history_helper(TreeNode<DataT>* tree_node, std::vector<DataT>* output) {
|
void get_history_helper(TreeNode<DataT> const* tree_node, TreeNode<DataT> const* root, std::vector<DataT>* output) {
|
||||||
if (tree_node == nullptr) return;
|
if (tree_node == root) return;
|
||||||
|
assert(tree_node != nullptr);
|
||||||
assert(tree_node->parent != tree_node);
|
assert(tree_node->parent != tree_node);
|
||||||
get_history_helper(tree_node->parent, output);
|
get_history_helper(tree_node->parent, root, output);
|
||||||
output->push_back(tree_node->data);
|
output->push_back(tree_node->data);
|
||||||
}
|
}
|
||||||
template<class DataT>
|
template<class DataT>
|
||||||
std::vector<DataT> get_history(TreeNode<DataT>* tree_node) {
|
std::vector<DataT> get_history(TreeNode<DataT> const* tree_node, TreeNode<DataT> const* root) {
|
||||||
std::vector<DataT> output;
|
std::vector<DataT> output;
|
||||||
get_history_helper(tree_node, &output);
|
get_history_helper(tree_node, root, &output);
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user