Only update time step of leaf prefixes

The intention of this check is to improve the accuracy of the timings by recording the time step where the character saw its highest probability rather than the first time step where it was seen. The problem happens when updating the time step of a prefix that already has children. In that case, if any of the children have a time step that is earlier than `new_timestep`, it'll break the linearity of the timings. My fix is to simply check that the prefix we're updating is a leaf.

For example, say during decoding we have the following beams (format is `(char | time)`, tree node id below, nodes with same id are the same object):

```
1. (-1 | 0 ) -> ('s' | 10) -> ('h' | 13) -> ('e' | 14)
        A                B                  C                D

2. (-1 | 0 ) -> ('s' | 10) -> ('h' | 14)
        A                B                  E
```

And the prefix list is [B, C, D, E]. Currently, if we process character 'h' in time step 15 with a probability higher than both C and E, we update both nodes to have time step 15, which breaks linearity in beam 1. With my fix, we only update node E, which is a leaf. In my tests this does fix the problem, but since we don't have any known good quality data to verify against, it's hard to know if it has other side effects.
This commit is contained in:
Reuben Morais 2019-08-14 15:30:04 +02:00
parent 86fff2f660
commit e3bf5d3cc6
2 changed files with 19 additions and 16 deletions

View File

@ -42,7 +42,7 @@ decoder_init(const Alphabet &alphabet,
auto matcher = std::make_shared<fst::SortedMatcher<PathTrie::FstType>>(*dict_ptr, fst::MATCH_INPUT);
root->set_matcher(matcher);
}
return state;
}
@ -57,7 +57,7 @@ decoder_next(const double *probs,
size_t beam_size,
Scorer *ext_scorer)
{
// prefix search over time
// prefix search over time
for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++state->time_step) {
auto *prob = &probs[rel_time_step*class_dim];
@ -67,7 +67,7 @@ decoder_next(const double *probs,
size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
std::sort(
state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);
min_cutoff = state->prefixes[num_prefixes - 1]->score +
std::log(prob[state->blank_id]) - std::max(0.0, ext_scorer->beta);
full_beam = (num_prefixes == beam_size);
@ -75,7 +75,7 @@ decoder_next(const double *probs,
std::vector<std::pair<size_t, float>> log_prob_idx =
get_pruned_log_probs(prob, class_dim, cutoff_prob, cutoff_top_n);
// loop over chars
// loop over class dim
for (size_t index = 0; index < log_prob_idx.size(); index++) {
auto c = log_prob_idx[index].first;
auto log_prob_c = log_prob_idx[index].second;
@ -135,14 +135,14 @@ decoder_next(const double *probs,
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
} // end of loop over prefix
} // end of loop over vocabulary
} // end of loop over alphabet
// update log probs
state->prefixes.clear();
state->prefix_root->iterate_to_vec(state->prefixes);
// only preserve top beam_size prefixes
if (state->prefixes.size() >= beam_size) {
if (state->prefixes.size() > beam_size) {
std::nth_element(state->prefixes.begin(),
state->prefixes.begin() + beam_size,
state->prefixes.end(),
@ -154,7 +154,6 @@ decoder_next(const double *probs,
// Remove the elements from std::vector
state->prefixes.resize(beam_size);
}
} // end of loop over time
}
@ -220,8 +219,8 @@ std::vector<Output> ctc_beam_search_decoder(
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer) {
Scorer *ext_scorer)
{
DecoderState *state = decoder_init(alphabet, class_dim, ext_scorer);
decoder_next(probs, alphabet, state, time_dim, class_dim, cutoff_prob, cutoff_top_n, beam_size, ext_scorer);
std::vector<Output> out = decoder_decode(state, alphabet, beam_size, ext_scorer);
@ -244,7 +243,8 @@ ctc_beam_search_decoder_batch(
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer) {
Scorer *ext_scorer)
{
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element");
// thread pool

View File

@ -39,7 +39,12 @@ PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_
auto child = children_.begin();
for (child = children_.begin(); child != children_.end(); ++child) {
if (child->first == new_char) {
if (child->second->log_prob_c < cur_log_prob_c) {
// If existing child matches this new_char but had a lower probability,
// and it's a leaf, update its timestep to new_timestep.
// The leak check makes sure we don't update the child to have a later
// timestep than a grandchild.
if (child->second->log_prob_c < cur_log_prob_c &&
child->second->children_.size() == 0) {
child->second->log_prob_c = cur_log_prob_c;
child->second->timestep = new_timestep;
}
@ -54,7 +59,7 @@ PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_
child->second->log_prob_b_cur = -NUM_FLT_INF;
child->second->log_prob_nb_cur = -NUM_FLT_INF;
}
return (child->second);
return child->second;
} else {
if (has_dictionary_) {
matcher_->SetState(dictionary_state_);
@ -145,9 +150,7 @@ void PathTrie::remove() {
exists_ = false;
if (children_.size() == 0) {
auto child = parent->children_.begin();
for (child = parent->children_.begin(); child != parent->children_.end();
++child) {
for (auto child = parent->children_.begin(); child != parent->children_.end(); ++child) {
if (child->first == character) {
parent->children_.erase(child);
break;