Merge pull request #2165 from mozilla/keep-stream-absolute-timestep
Keep absolute per-stream time step in DecoderState (Fixes #2163)
This commit is contained in:
commit
6f8c902f25
@ -25,6 +25,7 @@ DecoderState* decoder_init(const Alphabet &alphabet,
|
|||||||
|
|
||||||
// assign special ids
|
// assign special ids
|
||||||
DecoderState *state = new DecoderState;
|
DecoderState *state = new DecoderState;
|
||||||
|
state->time_step = 0;
|
||||||
state->space_id = alphabet.GetSpaceLabel();
|
state->space_id = alphabet.GetSpaceLabel();
|
||||||
state->blank_id = alphabet.GetSize();
|
state->blank_id = alphabet.GetSize();
|
||||||
|
|
||||||
@ -57,8 +58,8 @@ void decoder_next(const double *probs,
|
|||||||
Scorer *ext_scorer) {
|
Scorer *ext_scorer) {
|
||||||
|
|
||||||
// prefix search over time
|
// prefix search over time
|
||||||
for (size_t time_step = 0; time_step < time_dim; ++time_step) {
|
for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++state->time_step) {
|
||||||
auto *prob = &probs[time_step*class_dim];
|
auto *prob = &probs[rel_time_step*class_dim];
|
||||||
|
|
||||||
float min_cutoff = -NUM_FLT_INF;
|
float min_cutoff = -NUM_FLT_INF;
|
||||||
bool full_beam = false;
|
bool full_beam = false;
|
||||||
@ -99,7 +100,7 @@ void decoder_next(const double *probs,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// get new prefix
|
// get new prefix
|
||||||
auto prefix_new = prefix->get_path_trie(c, time_step, log_prob_c);
|
auto prefix_new = prefix->get_path_trie(c, state->time_step, log_prob_c);
|
||||||
|
|
||||||
if (prefix_new != nullptr) {
|
if (prefix_new != nullptr) {
|
||||||
float log_p = -NUM_FLT_INF;
|
float log_p = -NUM_FLT_INF;
|
||||||
|
|||||||
@ -6,6 +6,7 @@
|
|||||||
/* Struct for the state of the decoder, containing the prefixes and initial root prefix plus state variables. */
|
/* Struct for the state of the decoder, containing the prefixes and initial root prefix plus state variables. */
|
||||||
|
|
||||||
struct DecoderState {
|
struct DecoderState {
|
||||||
|
int time_step;
|
||||||
int space_id;
|
int space_id;
|
||||||
int blank_id;
|
int blank_id;
|
||||||
std::vector<PathTrie*> prefixes;
|
std::vector<PathTrie*> prefixes;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user