Merge pull request #2915 from mozilla/delay-beam-expansion (Fixes #2867)

Delay beam expansion until a non-blank label has probability >0.1%
This commit is contained in:
Reuben Morais 2020-04-17 21:03:08 +02:00 committed by GitHub
commit a019e979ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 10 deletions

View File

@ -29,6 +29,7 @@ DecoderState::init(const Alphabet& alphabet,
cutoff_prob_ = cutoff_prob;
cutoff_top_n_ = cutoff_top_n;
ext_scorer_ = ext_scorer;
start_expanding_ = false;
// init prefixes' root
PathTrie *root = new PathTrie;
@ -56,6 +57,19 @@ DecoderState::next(const double *probs,
for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++abs_time_step_) {
auto *prob = &probs[rel_time_step*class_dim];
// At the start of the decoding process, we delay beam expansion so that
// the timings of the first letters are not incorrect. As soon as we see a
// timestep with blank probability lower than 0.999, we start expanding
// beams.
if (prob[blank_id_] < 0.999) {
start_expanding_ = true;
}
// If not expanding yet, just continue to next timestep.
if (!start_expanding_) {
continue;
}
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (ext_scorer_) {

View File

@ -16,6 +16,7 @@ class DecoderState {
size_t beam_size_;
double cutoff_prob_;
size_t cutoff_top_n_;
bool start_expanding_;
std::shared_ptr<Scorer> ext_scorer_;
std::vector<PathTrie*> prefixes_;

View File

@ -37,17 +37,8 @@ PathTrie::~PathTrie() {
PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_prob_c, bool reset) {
auto child = children_.begin();
for (child = children_.begin(); child != children_.end(); ++child) {
for (; child != children_.end(); ++child) {
if (child->first == new_char) {
// If existing child matches this new_char but had a lower probability,
// and it's a leaf, update its timestep to new_timestep.
// The leaf check makes sure we don't update the child to have a later
// timestep than a grandchild.
if (child->second->log_prob_c < cur_log_prob_c &&
child->second->children_.size() == 0) {
child->second->log_prob_c = cur_log_prob_c;
child->second->timestep = new_timestep;
}
break;
}
}