Merge pull request #2302 from mozilla/issue2294
Only update time step of leaf prefixes (Fixes #2294)
This commit is contained in:
commit
b25de5ac05
@ -42,7 +42,7 @@ decoder_init(const Alphabet &alphabet,
|
|||||||
auto matcher = std::make_shared<fst::SortedMatcher<PathTrie::FstType>>(*dict_ptr, fst::MATCH_INPUT);
|
auto matcher = std::make_shared<fst::SortedMatcher<PathTrie::FstType>>(*dict_ptr, fst::MATCH_INPUT);
|
||||||
root->set_matcher(matcher);
|
root->set_matcher(matcher);
|
||||||
}
|
}
|
||||||
|
|
||||||
return state;
|
return state;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -57,7 +57,7 @@ decoder_next(const double *probs,
|
|||||||
size_t beam_size,
|
size_t beam_size,
|
||||||
Scorer *ext_scorer)
|
Scorer *ext_scorer)
|
||||||
{
|
{
|
||||||
// prefix search over time
|
// prefix search over time
|
||||||
for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++state->time_step) {
|
for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++state->time_step) {
|
||||||
auto *prob = &probs[rel_time_step*class_dim];
|
auto *prob = &probs[rel_time_step*class_dim];
|
||||||
|
|
||||||
@ -67,7 +67,7 @@ decoder_next(const double *probs,
|
|||||||
size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
|
size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
|
||||||
std::sort(
|
std::sort(
|
||||||
state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);
|
state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);
|
||||||
|
|
||||||
min_cutoff = state->prefixes[num_prefixes - 1]->score +
|
min_cutoff = state->prefixes[num_prefixes - 1]->score +
|
||||||
std::log(prob[state->blank_id]) - std::max(0.0, ext_scorer->beta);
|
std::log(prob[state->blank_id]) - std::max(0.0, ext_scorer->beta);
|
||||||
full_beam = (num_prefixes == beam_size);
|
full_beam = (num_prefixes == beam_size);
|
||||||
@ -75,7 +75,7 @@ decoder_next(const double *probs,
|
|||||||
|
|
||||||
std::vector<std::pair<size_t, float>> log_prob_idx =
|
std::vector<std::pair<size_t, float>> log_prob_idx =
|
||||||
get_pruned_log_probs(prob, class_dim, cutoff_prob, cutoff_top_n);
|
get_pruned_log_probs(prob, class_dim, cutoff_prob, cutoff_top_n);
|
||||||
// loop over chars
|
// loop over class dim
|
||||||
for (size_t index = 0; index < log_prob_idx.size(); index++) {
|
for (size_t index = 0; index < log_prob_idx.size(); index++) {
|
||||||
auto c = log_prob_idx[index].first;
|
auto c = log_prob_idx[index].first;
|
||||||
auto log_prob_c = log_prob_idx[index].second;
|
auto log_prob_c = log_prob_idx[index].second;
|
||||||
@ -135,14 +135,14 @@ decoder_next(const double *probs,
|
|||||||
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
|
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
|
||||||
}
|
}
|
||||||
} // end of loop over prefix
|
} // end of loop over prefix
|
||||||
} // end of loop over vocabulary
|
} // end of loop over alphabet
|
||||||
|
|
||||||
// update log probs
|
// update log probs
|
||||||
state->prefixes.clear();
|
state->prefixes.clear();
|
||||||
state->prefix_root->iterate_to_vec(state->prefixes);
|
state->prefix_root->iterate_to_vec(state->prefixes);
|
||||||
|
|
||||||
// only preserve top beam_size prefixes
|
// only preserve top beam_size prefixes
|
||||||
if (state->prefixes.size() >= beam_size) {
|
if (state->prefixes.size() > beam_size) {
|
||||||
std::nth_element(state->prefixes.begin(),
|
std::nth_element(state->prefixes.begin(),
|
||||||
state->prefixes.begin() + beam_size,
|
state->prefixes.begin() + beam_size,
|
||||||
state->prefixes.end(),
|
state->prefixes.end(),
|
||||||
@ -154,7 +154,6 @@ decoder_next(const double *probs,
|
|||||||
// Remove the elements from std::vector
|
// Remove the elements from std::vector
|
||||||
state->prefixes.resize(beam_size);
|
state->prefixes.resize(beam_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // end of loop over time
|
} // end of loop over time
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -220,8 +219,8 @@ std::vector<Output> ctc_beam_search_decoder(
|
|||||||
size_t beam_size,
|
size_t beam_size,
|
||||||
double cutoff_prob,
|
double cutoff_prob,
|
||||||
size_t cutoff_top_n,
|
size_t cutoff_top_n,
|
||||||
Scorer *ext_scorer) {
|
Scorer *ext_scorer)
|
||||||
|
{
|
||||||
DecoderState *state = decoder_init(alphabet, class_dim, ext_scorer);
|
DecoderState *state = decoder_init(alphabet, class_dim, ext_scorer);
|
||||||
decoder_next(probs, alphabet, state, time_dim, class_dim, cutoff_prob, cutoff_top_n, beam_size, ext_scorer);
|
decoder_next(probs, alphabet, state, time_dim, class_dim, cutoff_prob, cutoff_top_n, beam_size, ext_scorer);
|
||||||
std::vector<Output> out = decoder_decode(state, alphabet, beam_size, ext_scorer);
|
std::vector<Output> out = decoder_decode(state, alphabet, beam_size, ext_scorer);
|
||||||
@ -244,7 +243,8 @@ ctc_beam_search_decoder_batch(
|
|||||||
size_t num_processes,
|
size_t num_processes,
|
||||||
double cutoff_prob,
|
double cutoff_prob,
|
||||||
size_t cutoff_top_n,
|
size_t cutoff_top_n,
|
||||||
Scorer *ext_scorer) {
|
Scorer *ext_scorer)
|
||||||
|
{
|
||||||
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
|
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
|
||||||
VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element");
|
VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element");
|
||||||
// thread pool
|
// thread pool
|
||||||
|
|||||||
@ -39,7 +39,12 @@ PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_
|
|||||||
auto child = children_.begin();
|
auto child = children_.begin();
|
||||||
for (child = children_.begin(); child != children_.end(); ++child) {
|
for (child = children_.begin(); child != children_.end(); ++child) {
|
||||||
if (child->first == new_char) {
|
if (child->first == new_char) {
|
||||||
if (child->second->log_prob_c < cur_log_prob_c) {
|
// If existing child matches this new_char but had a lower probability,
|
||||||
|
// and it's a leaf, update its timestep to new_timestep.
|
||||||
|
// The leak check makes sure we don't update the child to have a later
|
||||||
|
// timestep than a grandchild.
|
||||||
|
if (child->second->log_prob_c < cur_log_prob_c &&
|
||||||
|
child->second->children_.size() == 0) {
|
||||||
child->second->log_prob_c = cur_log_prob_c;
|
child->second->log_prob_c = cur_log_prob_c;
|
||||||
child->second->timestep = new_timestep;
|
child->second->timestep = new_timestep;
|
||||||
}
|
}
|
||||||
@ -54,7 +59,7 @@ PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_
|
|||||||
child->second->log_prob_b_cur = -NUM_FLT_INF;
|
child->second->log_prob_b_cur = -NUM_FLT_INF;
|
||||||
child->second->log_prob_nb_cur = -NUM_FLT_INF;
|
child->second->log_prob_nb_cur = -NUM_FLT_INF;
|
||||||
}
|
}
|
||||||
return (child->second);
|
return child->second;
|
||||||
} else {
|
} else {
|
||||||
if (has_dictionary_) {
|
if (has_dictionary_) {
|
||||||
matcher_->SetState(dictionary_state_);
|
matcher_->SetState(dictionary_state_);
|
||||||
@ -145,9 +150,7 @@ void PathTrie::remove() {
|
|||||||
exists_ = false;
|
exists_ = false;
|
||||||
|
|
||||||
if (children_.size() == 0) {
|
if (children_.size() == 0) {
|
||||||
auto child = parent->children_.begin();
|
for (auto child = parent->children_.begin(); child != parent->children_.end(); ++child) {
|
||||||
for (child = parent->children_.begin(); child != parent->children_.end();
|
|
||||||
++child) {
|
|
||||||
if (child->first == character) {
|
if (child->first == character) {
|
||||||
parent->children_.erase(child);
|
parent->children_.erase(child);
|
||||||
break;
|
break;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user