diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.cpp b/native_client/ctcdecode/ctc_beam_search_decoder.cpp index 2f604b31..d12e81fc 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.cpp +++ b/native_client/ctcdecode/ctc_beam_search_decoder.cpp @@ -14,49 +14,61 @@ using FSTMATCH = fst::SortedMatcher; -std::vector ctc_beam_search_decoder( - const double *probs, - int time_dim, - int class_dim, - const Alphabet &alphabet, - size_t beam_size, - double cutoff_prob, - size_t cutoff_top_n, - Scorer *ext_scorer) { +DecoderState* decoder_init(const Alphabet &alphabet, + int class_dim, + Scorer* ext_scorer) { + // dimension check VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1, "The shape of probs does not match with " "the shape of the vocabulary"); // assign special ids - int space_id = alphabet.GetSpaceLabel(); - int blank_id = alphabet.GetSize(); + DecoderState *state = new DecoderState; + state->space_id = alphabet.GetSpaceLabel(); + state->blank_id = alphabet.GetSize(); // init prefixes' root - PathTrie root; - root.score = root.log_prob_b_prev = 0.0; - std::vector prefixes; - prefixes.push_back(&root); + PathTrie *root = new PathTrie; + root->score = root->log_prob_b_prev = 0.0; + + state->prefix_root = root; + + state->prefixes.push_back(root); if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { auto dict_ptr = ext_scorer->dictionary->Copy(true); - root.set_dictionary(dict_ptr); + root->set_dictionary(dict_ptr); auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); - root.set_matcher(matcher); + root->set_matcher(matcher); } + + return state; +} - // prefix search over time +void decoder_next(const double *probs, + const Alphabet &alphabet, + DecoderState *state, + int time_dim, + int class_dim, + double cutoff_prob, + size_t cutoff_top_n, + size_t beam_size, + Scorer *ext_scorer) { + + // prefix search over time for (size_t time_step = 0; time_step < time_dim; ++time_step) { auto *prob = &probs[time_step*class_dim]; float min_cutoff = -NUM_FLT_INF; bool full_beam = false; if (ext_scorer != nullptr) { - size_t num_prefixes = std::min(prefixes.size(), beam_size); + size_t num_prefixes = std::min(state->prefixes.size(), beam_size); std::sort( - prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); - min_cutoff = prefixes[num_prefixes - 1]->score + - std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta); + state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare); + + min_cutoff = state->prefixes[num_prefixes - 1]->score + + std::log(prob[state->blank_id]) - std::max(0.0, ext_scorer->beta); full_beam = (num_prefixes == beam_size); } @@ -67,22 +79,25 @@ std::vector ctc_beam_search_decoder( auto c = log_prob_idx[index].first; auto log_prob_c = log_prob_idx[index].second; - for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) { - auto prefix = prefixes[i]; + for (size_t i = 0; i < state->prefixes.size() && i < beam_size; ++i) { + auto prefix = state->prefixes[i]; if (full_beam && log_prob_c + prefix->score < min_cutoff) { break; } + // blank - if (c == blank_id) { + if (c == state->blank_id) { prefix->log_prob_b_cur = log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); continue; } + // repeated character if (c == prefix->character) { prefix->log_prob_nb_cur = log_sum_exp( prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev); } + // get new prefix auto prefix_new = prefix->get_path_trie(c, time_step, log_prob_c); @@ -98,7 +113,7 @@ std::vector ctc_beam_search_decoder( // language model scoring if (ext_scorer != nullptr && - (c == space_id || ext_scorer->is_character_based())) { + (c == state->space_id || ext_scorer->is_character_based())) { PathTrie *prefix_to_score = nullptr; // skip scoring the space if (ext_scorer->is_character_based()) { @@ -114,34 +129,41 @@ std::vector ctc_beam_search_decoder( log_p += score; log_p += ext_scorer->beta; } + prefix_new->log_prob_nb_cur = log_sum_exp(prefix_new->log_prob_nb_cur, log_p); } } // end of loop over prefix } // end of loop over vocabulary - - - prefixes.clear(); + // update log probs - root.iterate_to_vec(prefixes); + state->prefixes.clear(); + state->prefix_root->iterate_to_vec(state->prefixes); // only preserve top beam_size prefixes - if (prefixes.size() >= beam_size) { - std::nth_element(prefixes.begin(), - prefixes.begin() + beam_size, - prefixes.end(), + if (state->prefixes.size() >= beam_size) { + std::nth_element(state->prefixes.begin(), + state->prefixes.begin() + beam_size, + state->prefixes.end(), prefix_compare); - for (size_t i = beam_size; i < prefixes.size(); ++i) { - prefixes[i]->remove(); + for (size_t i = beam_size; i < state->prefixes.size(); ++i) { + state->prefixes[i]->remove(); } } + } // end of loop over time +} + +std::vector decoder_decode(DecoderState *state, + const Alphabet &alphabet, + size_t beam_size, + Scorer* ext_scorer) { // score the last word of each prefix that doesn't end with space if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - auto prefix = prefixes[i]; - if (!prefix->is_empty() && prefix->character != space_id) { + for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) { + auto prefix = state->prefixes[i]; + if (!prefix->is_empty() && prefix->character != state->space_id) { float score = 0.0; std::vector ngram = ext_scorer->make_ngram(prefix); score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; @@ -151,17 +173,17 @@ std::vector ctc_beam_search_decoder( } } - size_t num_prefixes = std::min(prefixes.size(), beam_size); - std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); + size_t num_prefixes = std::min(state->prefixes.size(), beam_size); + std::sort(state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare); // compute aproximate ctc score as the return score, without affecting the // return order of decoding result. To delete when decoder gets stable. - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - double approx_ctc = prefixes[i]->score; + for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) { + double approx_ctc = state->prefixes[i]->score; if (ext_scorer != nullptr) { std::vector output; std::vector timesteps; - prefixes[i]->get_path_vec(output, timesteps); + state->prefixes[i]->get_path_vec(output, timesteps); auto prefix_length = output.size(); auto words = ext_scorer->split_labels(output); // remove word insert @@ -169,12 +191,30 @@ std::vector ctc_beam_search_decoder( // remove language model weight: approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha; } - prefixes[i]->approx_ctc = approx_ctc; + state->prefixes[i]->approx_ctc = approx_ctc; } - return get_beam_search_result(prefixes, beam_size); + return get_beam_search_result(state->prefixes, beam_size); } +std::vector ctc_beam_search_decoder( + const double *probs, + int time_dim, + int class_dim, + const Alphabet &alphabet, + size_t beam_size, + double cutoff_prob, + size_t cutoff_top_n, + Scorer *ext_scorer) { + + DecoderState *state = decoder_init(alphabet, class_dim, ext_scorer); + decoder_next(probs, alphabet, state, time_dim, class_dim, cutoff_prob, cutoff_top_n, beam_size, ext_scorer); + std::vector out = decoder_decode(state, alphabet, beam_size, ext_scorer); + + delete state; + + return out; +} std::vector> ctc_beam_search_decoder_batch( diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.h b/native_client/ctcdecode/ctc_beam_search_decoder.h index f2daa63e..81f1b613 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.h +++ b/native_client/ctcdecode/ctc_beam_search_decoder.h @@ -7,12 +7,73 @@ #include "scorer.h" #include "output.h" #include "alphabet.h" +#include "decoderstate.h" -/* CTC Beam Search Decoder +/* Initialize CTC beam search decoder * Parameters: - * probs_seq: 2-D vector that each element is a vector of probabilities - * over alphabet of one time step. + * alphabet: The alphabet. + * class_dim: Alphabet length (plus 1 for space character). + * ext_scorer: External scorer to evaluate a prefix, which consists of + * n-gram language model scoring and word insertion term. + * Default null, decoding the input sample without scorer. + * Return: + * A struct containing prefixes and state variables. +*/ +DecoderState* decoder_init(const Alphabet &alphabet, + int class_dim, + Scorer *ext_scorer); + +/* Send data to the decoder + + * Parameters: + * probs: 2-D vector where each element is a vector of probabilities + * over alphabet of one time step. + * alphabet: The alphabet. + * state: The state structure previously obtained from decoder_init(). + * time_dim: Number of timesteps. + * class_dim: Alphabet length (plus 1 for space character). + * cutoff_prob: Cutoff probability for pruning. + * cutoff_top_n: Cutoff number for pruning. + * beam_size: The width of beam search. + * ext_scorer: External scorer to evaluate a prefix, which consists of + * n-gram language model scoring and word insertion term. + * Default null, decoding the input sample without scorer. +*/ +void decoder_next(const double *probs, + const Alphabet &alphabet, + DecoderState *state, + int time_dim, + int class_dim, + double cutoff_prob, + size_t cutoff_top_n, + size_t beam_size, + Scorer *ext_scorer); + +/* Get transcription for the data you sent via decoder_next() + + * Parameters: + * state: The state structure previously obtained from decoder_init(). + * alphabet: The alphabet. + * beam_size: The width of beam search. + * ext_scorer: External scorer to evaluate a prefix, which consists of + * n-gram language model scoring and word insertion term. + * Default null, decoding the input sample without scorer. + * Return: + * A vector where each element is a pair of score and decoding result, + * in descending order. +*/ +std::vector decoder_decode(DecoderState *state, + const Alphabet &alphabet, + size_t beam_size, + Scorer* ext_scorer); + +/* CTC Beam Search Decoder + * Parameters: + * probs: 2-D vector where each element is a vector of probabilities + * over alphabet of one time step. + * time_dim: Number of timesteps. + * class_dim: Alphabet length (plus 1 for space character). * alphabet: The alphabet. * beam_size: The width of beam search. * cutoff_prob: Cutoff probability for pruning. @@ -21,8 +82,8 @@ * n-gram language model scoring and word insertion term. * Default null, decoding the input sample without scorer. * Return: - * A vector that each element is a pair of score and decoding result, - * in desending order. + * A vector where each element is a pair of score and decoding result, + * in descending order. */ std::vector ctc_beam_search_decoder( @@ -36,9 +97,8 @@ std::vector ctc_beam_search_decoder( Scorer *ext_scorer); /* CTC Beam Search Decoder for batch data - * Parameters: - * probs_seq: 3-D vector that each element is a 2-D vector that can be used + * probs: 3-D vector where each element is a 2-D vector that can be used * by ctc_beam_search_decoder(). * alphabet: The alphabet. * beam_size: The width of beam search. @@ -49,7 +109,7 @@ std::vector ctc_beam_search_decoder( * n-gram language model scoring and word insertion term. * Default null, decoding the input sample without scorer. * Return: - * A 2-D vector that each element is a vector of beam search decoding + * A 2-D vector where each element is a vector of beam search decoding * result for one audio sample. */ std::vector> diff --git a/native_client/ctcdecode/decoderstate.h b/native_client/ctcdecode/decoderstate.h new file mode 100644 index 00000000..751255bf --- /dev/null +++ b/native_client/ctcdecode/decoderstate.h @@ -0,0 +1,22 @@ +#ifndef DECODERSTATE_H_ +#define DECODERSTATE_H_ + +#include + +/* Struct for the state of the decoder, containing the prefixes and initial root prefix plus state variables. */ + +struct DecoderState { + int space_id; + int blank_id; + std::vector prefixes; + PathTrie *prefix_root; + + ~DecoderState() { + if (prefix_root != nullptr) { + delete prefix_root; + } + prefix_root = nullptr; + } +}; + +#endif // DECODERSTATE_H_ diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc index fc0b2004..526c176c 100644 --- a/native_client/deepspeech.cc +++ b/native_client/deepspeech.cc @@ -75,13 +75,12 @@ using std::vector; API. When audio_buffer is full, features are computed from it and pushed to mfcc_buffer. When mfcc_buffer is full, the timestep is copied to batch_buffer. When batch_buffer is full, we do a single step through the acoustic model - and accumulate results in StreamingState::accumulated_logits. + and accumulate results in the DecoderState structure. - When fininshStream() is called, we decode the accumulated logits and return + When finishStream() is called, we decode the accumulated logits and return the corresponding transcription. */ struct StreamingState { - vector accumulated_logits; vector audio_buffer; vector mfcc_buffer; vector batch_buffer; @@ -113,6 +112,7 @@ struct ModelState { unsigned int ncontext; Alphabet* alphabet; Scorer* scorer; + DecoderState* decoder_state; unsigned int beam_width; unsigned int n_steps; unsigned int n_context; @@ -145,34 +145,26 @@ struct ModelState { * @brief Perform decoding of the logits, using basic CTC decoder or * CTC decoder with KenLM enabled * - * @param logits Flat matrix of logits, of size: - * n_frames * batch_size * num_classes - * * @return String representing the decoded text. */ - char* decode(const vector& logits); + char* decode(); /** * @brief Perform decoding of the logits, using basic CTC decoder or * CTC decoder with KenLM enabled * - * @param logits Flat matrix of logits, of size: - * n_frames * batch_size * num_classes - * * @return Vector of Output structs directly from the CTC decoder for additional processing. */ - vector decode_raw(const vector& logits); + vector decode_raw(); /** * @brief Return character-level metadata including letter timings. * - * @param logits Flat matrix of logits, of size: - * n_frames * batch_size * num_classes * * @return Metadata struct containing MetadataItem structs for each character. * The user is responsible for freeing Metadata by calling DS_FreeMetadata(). */ - Metadata* decode_metadata(const vector& logits); + Metadata* decode_metadata(); /** * @brief Do a single inference step in the acoustic model, with: @@ -202,6 +194,7 @@ ModelState::ModelState() , ncontext(0) , alphabet(nullptr) , scorer(nullptr) + , decoder_state(nullptr) , beam_width(0) , n_steps(-1) , n_context(-1) @@ -232,6 +225,11 @@ ModelState::~ModelState() delete scorer; delete alphabet; + + if (decoder_state != nullptr) { + delete decoder_state; + decoder_state = nullptr; + } } template @@ -270,21 +268,21 @@ StreamingState::feedAudioContent(const short* buffer, char* StreamingState::intermediateDecode() { - return model->decode(accumulated_logits); + return model->decode(); } char* StreamingState::finishStream() { finalizeStream(); - return model->decode(accumulated_logits); + return model->decode(); } Metadata* StreamingState::finishStreamWithMetadata() { finalizeStream(); - return model->decode_metadata(accumulated_logits); + return model->decode_metadata(); } void @@ -372,7 +370,26 @@ StreamingState::processMfccWindow(const vector& buf) void StreamingState::processBatch(const vector& buf, unsigned int n_steps) { - model->infer(buf.data(), n_steps, accumulated_logits); + vector logits; + model->infer(buf.data(), n_steps, logits); + + const int cutoff_top_n = 40; + const double cutoff_prob = 1.0; + const size_t num_classes = model->alphabet->GetSize() + 1; // +1 for blank + const int n_frames = logits.size() / (BATCH_SIZE * num_classes); + + // Convert logits to double + vector inputs(logits.begin(), logits.end()); + + decoder_next(inputs.data(), + *model->alphabet, + model->decoder_state, + n_frames, + num_classes, + cutoff_prob, + cutoff_top_n, + model->beam_width, + model->scorer); } void @@ -507,35 +524,24 @@ ModelState::compute_mfcc(const vector& samples, vector& mfcc_outpu } char* -ModelState::decode(const vector& logits) +ModelState::decode() { - vector out = ModelState::decode_raw(logits); + vector out = ModelState::decode_raw(); return strdup(alphabet->LabelsToString(out[0].tokens).c_str()); } vector -ModelState::decode_raw(const vector& logits) +ModelState::decode_raw() { - const int cutoff_top_n = 40; - const double cutoff_prob = 1.0; - const size_t num_classes = alphabet->GetSize() + 1; // +1 for blank - const int n_frames = logits.size() / (BATCH_SIZE * num_classes); - - // Convert logits to double - vector inputs(logits.begin(), logits.end()); - - // Vector of pairs - vector out = ctc_beam_search_decoder( - inputs.data(), n_frames, num_classes, *alphabet, beam_width, - cutoff_prob, cutoff_top_n, scorer); + vector out = decoder_decode(decoder_state, *alphabet, beam_width, scorer); return out; } Metadata* -ModelState::decode_metadata(const vector& logits) +ModelState::decode_metadata() { - vector out = decode_raw(logits); + vector out = decode_raw(); std::unique_ptr metadata(new Metadata()); metadata->num_items = out[0].tokens.size(); @@ -830,8 +836,6 @@ DS_SetupStream(ModelState* aCtx, aPreAllocFrames = 150; } - ctx->accumulated_logits.reserve(aPreAllocFrames * BATCH_SIZE * num_classes); - ctx->audio_buffer.reserve(aCtx->audio_win_len); ctx->mfcc_buffer.reserve(aCtx->mfcc_feats_per_timestep); ctx->mfcc_buffer.resize(aCtx->n_features*aCtx->n_context, 0.f); @@ -839,6 +843,9 @@ DS_SetupStream(ModelState* aCtx, ctx->model = aCtx; + DecoderState *params = decoder_init(*aCtx->alphabet, num_classes, aCtx->scorer); + aCtx->decoder_state = params; + *retval = ctx.release(); return DS_ERR_OK; }