Merge pull request #2302 from mozilla/issue2294
Only update time step of leaf prefixes (Fixes #2294)
This commit is contained in:
commit
b25de5ac05
@ -42,7 +42,7 @@ decoder_init(const Alphabet &alphabet,
|
||||
auto matcher = std::make_shared<fst::SortedMatcher<PathTrie::FstType>>(*dict_ptr, fst::MATCH_INPUT);
|
||||
root->set_matcher(matcher);
|
||||
}
|
||||
|
||||
|
||||
return state;
|
||||
}
|
||||
|
||||
@ -57,7 +57,7 @@ decoder_next(const double *probs,
|
||||
size_t beam_size,
|
||||
Scorer *ext_scorer)
|
||||
{
|
||||
// prefix search over time
|
||||
// prefix search over time
|
||||
for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++state->time_step) {
|
||||
auto *prob = &probs[rel_time_step*class_dim];
|
||||
|
||||
@ -67,7 +67,7 @@ decoder_next(const double *probs,
|
||||
size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
|
||||
std::sort(
|
||||
state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);
|
||||
|
||||
|
||||
min_cutoff = state->prefixes[num_prefixes - 1]->score +
|
||||
std::log(prob[state->blank_id]) - std::max(0.0, ext_scorer->beta);
|
||||
full_beam = (num_prefixes == beam_size);
|
||||
@ -75,7 +75,7 @@ decoder_next(const double *probs,
|
||||
|
||||
std::vector<std::pair<size_t, float>> log_prob_idx =
|
||||
get_pruned_log_probs(prob, class_dim, cutoff_prob, cutoff_top_n);
|
||||
// loop over chars
|
||||
// loop over class dim
|
||||
for (size_t index = 0; index < log_prob_idx.size(); index++) {
|
||||
auto c = log_prob_idx[index].first;
|
||||
auto log_prob_c = log_prob_idx[index].second;
|
||||
@ -135,14 +135,14 @@ decoder_next(const double *probs,
|
||||
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
|
||||
}
|
||||
} // end of loop over prefix
|
||||
} // end of loop over vocabulary
|
||||
|
||||
} // end of loop over alphabet
|
||||
|
||||
// update log probs
|
||||
state->prefixes.clear();
|
||||
state->prefix_root->iterate_to_vec(state->prefixes);
|
||||
|
||||
// only preserve top beam_size prefixes
|
||||
if (state->prefixes.size() >= beam_size) {
|
||||
if (state->prefixes.size() > beam_size) {
|
||||
std::nth_element(state->prefixes.begin(),
|
||||
state->prefixes.begin() + beam_size,
|
||||
state->prefixes.end(),
|
||||
@ -154,7 +154,6 @@ decoder_next(const double *probs,
|
||||
// Remove the elements from std::vector
|
||||
state->prefixes.resize(beam_size);
|
||||
}
|
||||
|
||||
} // end of loop over time
|
||||
}
|
||||
|
||||
@ -220,8 +219,8 @@ std::vector<Output> ctc_beam_search_decoder(
|
||||
size_t beam_size,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
Scorer *ext_scorer) {
|
||||
|
||||
Scorer *ext_scorer)
|
||||
{
|
||||
DecoderState *state = decoder_init(alphabet, class_dim, ext_scorer);
|
||||
decoder_next(probs, alphabet, state, time_dim, class_dim, cutoff_prob, cutoff_top_n, beam_size, ext_scorer);
|
||||
std::vector<Output> out = decoder_decode(state, alphabet, beam_size, ext_scorer);
|
||||
@ -244,7 +243,8 @@ ctc_beam_search_decoder_batch(
|
||||
size_t num_processes,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
Scorer *ext_scorer) {
|
||||
Scorer *ext_scorer)
|
||||
{
|
||||
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
|
||||
VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element");
|
||||
// thread pool
|
||||
|
@ -39,7 +39,12 @@ PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_
|
||||
auto child = children_.begin();
|
||||
for (child = children_.begin(); child != children_.end(); ++child) {
|
||||
if (child->first == new_char) {
|
||||
if (child->second->log_prob_c < cur_log_prob_c) {
|
||||
// If existing child matches this new_char but had a lower probability,
|
||||
// and it's a leaf, update its timestep to new_timestep.
|
||||
// The leaf check makes sure we don't update the child to have a later
|
||||
// timestep than a grandchild.
|
||||
if (child->second->log_prob_c < cur_log_prob_c &&
|
||||
child->second->children_.size() == 0) {
|
||||
child->second->log_prob_c = cur_log_prob_c;
|
||||
child->second->timestep = new_timestep;
|
||||
}
|
||||
@ -54,7 +59,7 @@ PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_
|
||||
child->second->log_prob_b_cur = -NUM_FLT_INF;
|
||||
child->second->log_prob_nb_cur = -NUM_FLT_INF;
|
||||
}
|
||||
return (child->second);
|
||||
return child->second;
|
||||
} else {
|
||||
if (has_dictionary_) {
|
||||
matcher_->SetState(dictionary_state_);
|
||||
@ -145,9 +150,7 @@ void PathTrie::remove() {
|
||||
exists_ = false;
|
||||
|
||||
if (children_.size() == 0) {
|
||||
auto child = parent->children_.begin();
|
||||
for (child = parent->children_.begin(); child != parent->children_.end();
|
||||
++child) {
|
||||
for (auto child = parent->children_.begin(); child != parent->children_.end(); ++child) {
|
||||
if (child->first == character) {
|
||||
parent->children_.erase(child);
|
||||
break;
|
||||
|
Loading…
x
Reference in New Issue
Block a user