Merge pull request #2302 from mozilla/issue2294

Only update time step of leaf prefixes (Fixes #2294)
This commit is contained in:
Reuben Morais 2019-08-20 12:04:35 +02:00 committed by GitHub
commit b25de5ac05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 16 deletions

View File

@ -75,7 +75,7 @@ decoder_next(const double *probs,
std::vector<std::pair<size_t, float>> log_prob_idx = std::vector<std::pair<size_t, float>> log_prob_idx =
get_pruned_log_probs(prob, class_dim, cutoff_prob, cutoff_top_n); get_pruned_log_probs(prob, class_dim, cutoff_prob, cutoff_top_n);
// loop over chars // loop over class dim
for (size_t index = 0; index < log_prob_idx.size(); index++) { for (size_t index = 0; index < log_prob_idx.size(); index++) {
auto c = log_prob_idx[index].first; auto c = log_prob_idx[index].first;
auto log_prob_c = log_prob_idx[index].second; auto log_prob_c = log_prob_idx[index].second;
@ -135,14 +135,14 @@ decoder_next(const double *probs,
log_sum_exp(prefix_new->log_prob_nb_cur, log_p); log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
} }
} // end of loop over prefix } // end of loop over prefix
} // end of loop over vocabulary } // end of loop over alphabet
// update log probs // update log probs
state->prefixes.clear(); state->prefixes.clear();
state->prefix_root->iterate_to_vec(state->prefixes); state->prefix_root->iterate_to_vec(state->prefixes);
// only preserve top beam_size prefixes // only preserve top beam_size prefixes
if (state->prefixes.size() >= beam_size) { if (state->prefixes.size() > beam_size) {
std::nth_element(state->prefixes.begin(), std::nth_element(state->prefixes.begin(),
state->prefixes.begin() + beam_size, state->prefixes.begin() + beam_size,
state->prefixes.end(), state->prefixes.end(),
@ -154,7 +154,6 @@ decoder_next(const double *probs,
// Remove the elements from std::vector // Remove the elements from std::vector
state->prefixes.resize(beam_size); state->prefixes.resize(beam_size);
} }
} // end of loop over time } // end of loop over time
} }
@ -220,8 +219,8 @@ std::vector<Output> ctc_beam_search_decoder(
size_t beam_size, size_t beam_size,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
Scorer *ext_scorer) { Scorer *ext_scorer)
{
DecoderState *state = decoder_init(alphabet, class_dim, ext_scorer); DecoderState *state = decoder_init(alphabet, class_dim, ext_scorer);
decoder_next(probs, alphabet, state, time_dim, class_dim, cutoff_prob, cutoff_top_n, beam_size, ext_scorer); decoder_next(probs, alphabet, state, time_dim, class_dim, cutoff_prob, cutoff_top_n, beam_size, ext_scorer);
std::vector<Output> out = decoder_decode(state, alphabet, beam_size, ext_scorer); std::vector<Output> out = decoder_decode(state, alphabet, beam_size, ext_scorer);
@ -244,7 +243,8 @@ ctc_beam_search_decoder_batch(
size_t num_processes, size_t num_processes,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
Scorer *ext_scorer) { Scorer *ext_scorer)
{
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element"); VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element");
// thread pool // thread pool

View File

@ -39,7 +39,12 @@ PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_
auto child = children_.begin(); auto child = children_.begin();
for (child = children_.begin(); child != children_.end(); ++child) { for (child = children_.begin(); child != children_.end(); ++child) {
if (child->first == new_char) { if (child->first == new_char) {
if (child->second->log_prob_c < cur_log_prob_c) { // If existing child matches this new_char but had a lower probability,
// and it's a leaf, update its timestep to new_timestep.
// The leaf check makes sure we don't update the child to have a later
// timestep than a grandchild.
if (child->second->log_prob_c < cur_log_prob_c &&
child->second->children_.size() == 0) {
child->second->log_prob_c = cur_log_prob_c; child->second->log_prob_c = cur_log_prob_c;
child->second->timestep = new_timestep; child->second->timestep = new_timestep;
} }
@ -54,7 +59,7 @@ PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_
child->second->log_prob_b_cur = -NUM_FLT_INF; child->second->log_prob_b_cur = -NUM_FLT_INF;
child->second->log_prob_nb_cur = -NUM_FLT_INF; child->second->log_prob_nb_cur = -NUM_FLT_INF;
} }
return (child->second); return child->second;
} else { } else {
if (has_dictionary_) { if (has_dictionary_) {
matcher_->SetState(dictionary_state_); matcher_->SetState(dictionary_state_);
@ -145,9 +150,7 @@ void PathTrie::remove() {
exists_ = false; exists_ = false;
if (children_.size() == 0) { if (children_.size() == 0) {
auto child = parent->children_.begin(); for (auto child = parent->children_.begin(); child != parent->children_.end(); ++child) {
for (child = parent->children_.begin(); child != parent->children_.end();
++child) {
if (child->first == character) { if (child->first == character) {
parent->children_.erase(child); parent->children_.erase(child);
break; break;