diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.cpp b/native_client/ctcdecode/ctc_beam_search_decoder.cpp index a50f731f..35461c4e 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.cpp +++ b/native_client/ctcdecode/ctc_beam_search_decoder.cpp @@ -25,15 +25,16 @@ DecoderState* decoder_init(const Alphabet &alphabet, // assign special ids DecoderState *state = new DecoderState; + state->time_step = 0; state->space_id = alphabet.GetSpaceLabel(); state->blank_id = alphabet.GetSize(); // init prefixes' root PathTrie *root = new PathTrie; root->score = root->log_prob_b_prev = 0.0; - + state->prefix_root = root; - + state->prefixes.push_back(root); if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { @@ -57,8 +58,8 @@ void decoder_next(const double *probs, Scorer *ext_scorer) { // prefix search over time - for (size_t time_step = 0; time_step < time_dim; ++time_step) { - auto *prob = &probs[time_step*class_dim]; + for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++state->time_step) { + auto *prob = &probs[rel_time_step*class_dim]; float min_cutoff = -NUM_FLT_INF; bool full_beam = false; @@ -99,7 +100,7 @@ void decoder_next(const double *probs, } // get new prefix - auto prefix_new = prefix->get_path_trie(c, time_step, log_prob_c); + auto prefix_new = prefix->get_path_trie(c, state->time_step, log_prob_c); if (prefix_new != nullptr) { float log_p = -NUM_FLT_INF; diff --git a/native_client/ctcdecode/decoderstate.h b/native_client/ctcdecode/decoderstate.h index 751255bf..1a92e80e 100644 --- a/native_client/ctcdecode/decoderstate.h +++ b/native_client/ctcdecode/decoderstate.h @@ -6,6 +6,7 @@ /* Struct for the state of the decoder, containing the prefixes and initial root prefix plus state variables. */ struct DecoderState { + int time_step; int space_id; int blank_id; std::vector prefixes;