Merge pull request #2165 from mozilla/keep-stream-absolute-timestep

Keep absolute per-stream time step in DecoderState (Fixes #2163)
This commit is contained in:
Reuben Morais 2019-06-11 11:23:04 -03:00 committed by GitHub
commit 6f8c902f25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 5 deletions

View File

@ -25,15 +25,16 @@ DecoderState* decoder_init(const Alphabet &alphabet,
// assign special ids
DecoderState *state = new DecoderState;
state->time_step = 0;
state->space_id = alphabet.GetSpaceLabel();
state->blank_id = alphabet.GetSize();
// init prefixes' root
PathTrie *root = new PathTrie;
root->score = root->log_prob_b_prev = 0.0;
state->prefix_root = root;
state->prefixes.push_back(root);
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
@ -57,8 +58,8 @@ void decoder_next(const double *probs,
Scorer *ext_scorer) {
// prefix search over time
for (size_t time_step = 0; time_step < time_dim; ++time_step) {
auto *prob = &probs[time_step*class_dim];
for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++state->time_step) {
auto *prob = &probs[rel_time_step*class_dim];
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
@ -99,7 +100,7 @@ void decoder_next(const double *probs,
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c, time_step, log_prob_c);
auto prefix_new = prefix->get_path_trie(c, state->time_step, log_prob_c);
if (prefix_new != nullptr) {
float log_p = -NUM_FLT_INF;

View File

@ -6,6 +6,7 @@
/* Struct for the state of the decoder, containing the prefixes and initial root prefix plus state variables. */
struct DecoderState {
int time_step;
int space_id;
int blank_id;
std::vector<PathTrie*> prefixes;