Merge pull request #2302 from mozilla/issue2294

Only update time step of leaf prefixes (Fixes #2294)
This commit is contained in:
Reuben Morais 2019-08-20 12:04:35 +02:00 committed by GitHub
commit b25de5ac05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 16 deletions

View File

@ -42,7 +42,7 @@ decoder_init(const Alphabet &alphabet,
auto matcher = std::make_shared<fst::SortedMatcher<PathTrie::FstType>>(*dict_ptr, fst::MATCH_INPUT);
root->set_matcher(matcher);
}
return state;
}
@ -57,7 +57,7 @@ decoder_next(const double *probs,
size_t beam_size,
Scorer *ext_scorer)
{
// prefix search over time
// prefix search over time
for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++state->time_step) {
auto *prob = &probs[rel_time_step*class_dim];
@ -67,7 +67,7 @@ decoder_next(const double *probs,
size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
std::sort(
state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);
min_cutoff = state->prefixes[num_prefixes - 1]->score +
std::log(prob[state->blank_id]) - std::max(0.0, ext_scorer->beta);
full_beam = (num_prefixes == beam_size);
@ -75,7 +75,7 @@ decoder_next(const double *probs,
std::vector<std::pair<size_t, float>> log_prob_idx =
get_pruned_log_probs(prob, class_dim, cutoff_prob, cutoff_top_n);
// loop over chars
// loop over class dim
for (size_t index = 0; index < log_prob_idx.size(); index++) {
auto c = log_prob_idx[index].first;
auto log_prob_c = log_prob_idx[index].second;
@ -135,14 +135,14 @@ decoder_next(const double *probs,
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
} // end of loop over prefix
} // end of loop over vocabulary
} // end of loop over alphabet
// update log probs
state->prefixes.clear();
state->prefix_root->iterate_to_vec(state->prefixes);
// only preserve top beam_size prefixes
if (state->prefixes.size() >= beam_size) {
if (state->prefixes.size() > beam_size) {
std::nth_element(state->prefixes.begin(),
state->prefixes.begin() + beam_size,
state->prefixes.end(),
@ -154,7 +154,6 @@ decoder_next(const double *probs,
// Remove the elements from std::vector
state->prefixes.resize(beam_size);
}
} // end of loop over time
}
@ -220,8 +219,8 @@ std::vector<Output> ctc_beam_search_decoder(
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer) {
Scorer *ext_scorer)
{
DecoderState *state = decoder_init(alphabet, class_dim, ext_scorer);
decoder_next(probs, alphabet, state, time_dim, class_dim, cutoff_prob, cutoff_top_n, beam_size, ext_scorer);
std::vector<Output> out = decoder_decode(state, alphabet, beam_size, ext_scorer);
@ -244,7 +243,8 @@ ctc_beam_search_decoder_batch(
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer) {
Scorer *ext_scorer)
{
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element");
// thread pool

View File

@ -39,7 +39,12 @@ PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_
auto child = children_.begin();
for (child = children_.begin(); child != children_.end(); ++child) {
if (child->first == new_char) {
if (child->second->log_prob_c < cur_log_prob_c) {
// If existing child matches this new_char but had a lower probability,
// and it's a leaf, update its timestep to new_timestep.
// The leak check makes sure we don't update the child to have a later
// timestep than a grandchild.
if (child->second->log_prob_c < cur_log_prob_c &&
child->second->children_.size() == 0) {
child->second->log_prob_c = cur_log_prob_c;
child->second->timestep = new_timestep;
}
@ -54,7 +59,7 @@ PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_
child->second->log_prob_b_cur = -NUM_FLT_INF;
child->second->log_prob_nb_cur = -NUM_FLT_INF;
}
return (child->second);
return child->second;
} else {
if (has_dictionary_) {
matcher_->SetState(dictionary_state_);
@ -145,9 +150,7 @@ void PathTrie::remove() {
exists_ = false;
if (children_.size() == 0) {
auto child = parent->children_.begin();
for (child = parent->children_.begin(); child != parent->children_.end();
++child) {
for (auto child = parent->children_.begin(); child != parent->children_.end(); ++child) {
if (child->first == character) {
parent->children_.erase(child);
break;