Merge pull request #2165 from mozilla/keep-stream-absolute-timestep

Keep absolute per-stream time step in DecoderState (Fixes #2163)
This commit is contained in:
Reuben Morais 2019-06-11 11:23:04 -03:00 committed by GitHub
commit 6f8c902f25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 5 deletions

View File

@ -25,6 +25,7 @@ DecoderState* decoder_init(const Alphabet &alphabet,
// assign special ids // assign special ids
DecoderState *state = new DecoderState; DecoderState *state = new DecoderState;
state->time_step = 0;
state->space_id = alphabet.GetSpaceLabel(); state->space_id = alphabet.GetSpaceLabel();
state->blank_id = alphabet.GetSize(); state->blank_id = alphabet.GetSize();
@ -57,8 +58,8 @@ void decoder_next(const double *probs,
Scorer *ext_scorer) { Scorer *ext_scorer) {
// prefix search over time // prefix search over time
for (size_t time_step = 0; time_step < time_dim; ++time_step) { for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++state->time_step) {
auto *prob = &probs[time_step*class_dim]; auto *prob = &probs[rel_time_step*class_dim];
float min_cutoff = -NUM_FLT_INF; float min_cutoff = -NUM_FLT_INF;
bool full_beam = false; bool full_beam = false;
@ -99,7 +100,7 @@ void decoder_next(const double *probs,
} }
// get new prefix // get new prefix
auto prefix_new = prefix->get_path_trie(c, time_step, log_prob_c); auto prefix_new = prefix->get_path_trie(c, state->time_step, log_prob_c);
if (prefix_new != nullptr) { if (prefix_new != nullptr) {
float log_p = -NUM_FLT_INF; float log_p = -NUM_FLT_INF;

View File

@ -6,6 +6,7 @@
/* Struct for the state of the decoder, containing the prefixes and initial root prefix plus state variables. */ /* Struct for the state of the decoder, containing the prefixes and initial root prefix plus state variables. */
struct DecoderState { struct DecoderState {
int time_step;
int space_id; int space_id;
int blank_id; int blank_id;
std::vector<PathTrie*> prefixes; std::vector<PathTrie*> prefixes;