Merge pull request #2165 from mozilla/keep-stream-absolute-timestep
Keep absolute per-stream time step in DecoderState (Fixes #2163)
This commit is contained in:
commit
6f8c902f25
@ -25,15 +25,16 @@ DecoderState* decoder_init(const Alphabet &alphabet,
|
||||
|
||||
// assign special ids
|
||||
DecoderState *state = new DecoderState;
|
||||
state->time_step = 0;
|
||||
state->space_id = alphabet.GetSpaceLabel();
|
||||
state->blank_id = alphabet.GetSize();
|
||||
|
||||
// init prefixes' root
|
||||
PathTrie *root = new PathTrie;
|
||||
root->score = root->log_prob_b_prev = 0.0;
|
||||
|
||||
|
||||
state->prefix_root = root;
|
||||
|
||||
|
||||
state->prefixes.push_back(root);
|
||||
|
||||
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
|
||||
@ -57,8 +58,8 @@ void decoder_next(const double *probs,
|
||||
Scorer *ext_scorer) {
|
||||
|
||||
// prefix search over time
|
||||
for (size_t time_step = 0; time_step < time_dim; ++time_step) {
|
||||
auto *prob = &probs[time_step*class_dim];
|
||||
for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++state->time_step) {
|
||||
auto *prob = &probs[rel_time_step*class_dim];
|
||||
|
||||
float min_cutoff = -NUM_FLT_INF;
|
||||
bool full_beam = false;
|
||||
@ -99,7 +100,7 @@ void decoder_next(const double *probs,
|
||||
}
|
||||
|
||||
// get new prefix
|
||||
auto prefix_new = prefix->get_path_trie(c, time_step, log_prob_c);
|
||||
auto prefix_new = prefix->get_path_trie(c, state->time_step, log_prob_c);
|
||||
|
||||
if (prefix_new != nullptr) {
|
||||
float log_p = -NUM_FLT_INF;
|
||||
|
@ -6,6 +6,7 @@
|
||||
/* Struct for the state of the decoder, containing the prefixes and initial root prefix plus state variables. */
|
||||
|
||||
struct DecoderState {
|
||||
int time_step;
|
||||
int space_id;
|
||||
int blank_id;
|
||||
std::vector<PathTrie*> prefixes;
|
||||
|
Loading…
x
Reference in New Issue
Block a user