Delay beam expansion until a non-blank label has probability >0.1%
This commit is contained in:
parent
7efdfc54a6
commit
33760a6bcd
@ -29,6 +29,7 @@ DecoderState::init(const Alphabet& alphabet,
|
||||
cutoff_prob_ = cutoff_prob;
|
||||
cutoff_top_n_ = cutoff_top_n;
|
||||
ext_scorer_ = ext_scorer;
|
||||
start_expanding_ = false;
|
||||
|
||||
// init prefixes' root
|
||||
PathTrie *root = new PathTrie;
|
||||
@ -56,6 +57,19 @@ DecoderState::next(const double *probs,
|
||||
for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++abs_time_step_) {
|
||||
auto *prob = &probs[rel_time_step*class_dim];
|
||||
|
||||
// At the start of the decoding process, we delay beam expansion so that
|
||||
// timings on the first letters are not incorrect. As soon as we see a
|
||||
// timestep with blank probability lower than 0.999, we start expanding
|
||||
// beams.
|
||||
if (prob[blank_id_] < 0.999) {
|
||||
start_expanding_ = true;
|
||||
}
|
||||
|
||||
// If not expanding yet, just continue to next timestep.
|
||||
if (!start_expanding_) {
|
||||
continue;
|
||||
}
|
||||
|
||||
float min_cutoff = -NUM_FLT_INF;
|
||||
bool full_beam = false;
|
||||
if (ext_scorer_) {
|
||||
|
||||
@ -16,6 +16,7 @@ class DecoderState {
|
||||
size_t beam_size_;
|
||||
double cutoff_prob_;
|
||||
size_t cutoff_top_n_;
|
||||
bool start_expanding_;
|
||||
|
||||
std::shared_ptr<Scorer> ext_scorer_;
|
||||
std::vector<PathTrie*> prefixes_;
|
||||
|
||||
@ -37,17 +37,8 @@ PathTrie::~PathTrie() {
|
||||
|
||||
PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_prob_c, bool reset) {
|
||||
auto child = children_.begin();
|
||||
for (child = children_.begin(); child != children_.end(); ++child) {
|
||||
for (; child != children_.end(); ++child) {
|
||||
if (child->first == new_char) {
|
||||
// If existing child matches this new_char but had a lower probability,
|
||||
// and it's a leaf, update its timestep to new_timestep.
|
||||
// The leaf check makes sure we don't update the child to have a later
|
||||
// timestep than a grandchild.
|
||||
if (child->second->log_prob_c < cur_log_prob_c &&
|
||||
child->second->children_.size() == 0) {
|
||||
child->second->log_prob_c = cur_log_prob_c;
|
||||
child->second->timestep = new_timestep;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user