From 440893c58d668cc0d975c48b5fc3971feaf9c7c3 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Fri, 2 Nov 2018 13:43:16 -0300 Subject: [PATCH] Simplify ctcdecode API signatures to avoid nested STL structures Flatten structure to avoid nested STL structures which are awkward to wrap with SWIG and slower at runtime. --- .../ctcdecode/ctc_beam_search_decoder.cpp | 42 ++++++++++--------- .../ctcdecode/ctc_beam_search_decoder.h | 38 ++++++++++------- native_client/ctcdecode/decoder_utils.cpp | 24 +++++------ native_client/ctcdecode/decoder_utils.h | 7 +++- native_client/ctcdecode/output.h | 6 ++- native_client/deepspeech.cc | 18 ++++---- 6 files changed, 72 insertions(+), 63 deletions(-) diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.cpp b/native_client/ctcdecode/ctc_beam_search_decoder.cpp index 2df2eb10..aa8ff75d 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.cpp +++ b/native_client/ctcdecode/ctc_beam_search_decoder.cpp @@ -14,21 +14,19 @@ using FSTMATCH = fst::SortedMatcher; -std::vector> ctc_beam_search_decoder( - const std::vector> &probs_seq, +std::vector ctc_beam_search_decoder( + const double *probs, + int time_dim, + int class_dim, const Alphabet &alphabet, size_t beam_size, double cutoff_prob, size_t cutoff_top_n, Scorer *ext_scorer) { // dimension check - size_t num_time_steps = probs_seq.size(); - for (size_t i = 0; i < num_time_steps; ++i) { - VALID_CHECK_EQ(probs_seq[i].size(), - alphabet.GetSize()+1, - "The shape of probs_seq does not match with " - "the shape of the vocabulary"); - } + VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1, + "The shape of probs does not match with " + "the shape of the vocabulary"); // assign special ids int space_id = alphabet.GetSpaceLabel(); @@ -48,8 +46,8 @@ std::vector> ctc_beam_search_decoder( } // prefix search over time - for (size_t time_step = 0; time_step < num_time_steps; ++time_step) { - auto &prob = probs_seq[time_step]; + for (size_t time_step = 0; time_step < time_dim; ++time_step) { + auto *prob = &probs[time_step*class_dim]; float min_cutoff = -NUM_FLT_INF; bool full_beam = false; @@ -63,7 +61,7 @@ std::vector> ctc_beam_search_decoder( } std::vector> log_prob_idx = - get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n); + get_pruned_log_probs(prob, class_dim, cutoff_prob, cutoff_top_n); // loop over chars for (size_t index = 0; index < log_prob_idx.size(); index++) { auto c = log_prob_idx[index].first; @@ -178,9 +176,14 @@ std::vector> ctc_beam_search_decoder( } -std::vector>> +std::vector> ctc_beam_search_decoder_batch( - const std::vector>> &probs_split, + const double *probs, + int batch_size, + int time_dim, + int class_dim, + const int* seq_lengths, + int seq_lengths_size, const Alphabet &alphabet, size_t beam_size, size_t num_processes, @@ -188,16 +191,17 @@ ctc_beam_search_decoder_batch( size_t cutoff_top_n, Scorer *ext_scorer) { VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); + VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per sequence"); // thread pool ThreadPool pool(num_processes); - // number of samples - size_t batch_size = probs_split.size(); // enqueue the tasks of decoding - std::vector>>> res; + std::vector>> res; for (size_t i = 0; i < batch_size; ++i) { res.emplace_back(pool.enqueue(ctc_beam_search_decoder, - probs_split[i], + &probs[i*time_dim*class_dim], + seq_lengths[i], + class_dim, alphabet, beam_size, cutoff_prob, @@ -206,7 +210,7 @@ ctc_beam_search_decoder_batch( } // get decoding results - std::vector>> batch_results; + std::vector> batch_results; for (size_t i = 0; i < batch_size; ++i) { batch_results.emplace_back(res[i].get()); } diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.h b/native_client/ctcdecode/ctc_beam_search_decoder.h index 5fab441e..f2daa63e 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.h +++ b/native_client/ctcdecode/ctc_beam_search_decoder.h @@ -2,7 +2,6 @@ #define CTC_BEAM_SEARCH_DECODER_H_ #include -#include #include #include "scorer.h" @@ -13,8 +12,8 @@ * Parameters: * probs_seq: 2-D vector that each element is a vector of probabilities - * over vocabulary of one time step. - * vocabulary: A vector of vocabulary. + * over alphabet of one time step. + * alphabet: The alphabet. * beam_size: The width of beam search. * cutoff_prob: Cutoff probability for pruning. * cutoff_top_n: Cutoff number for pruning. @@ -26,20 +25,22 @@ * in desending order. */ -std::vector> ctc_beam_search_decoder( - const std::vector> &probs_seq, - const Alphabet &vocabulary, +std::vector ctc_beam_search_decoder( + const double* probs, + int time_dim, + int class_dim, + const Alphabet &alphabet, size_t beam_size, - double cutoff_prob = 1.0, - size_t cutoff_top_n = 40, - Scorer *ext_scorer = nullptr); + double cutoff_prob, + size_t cutoff_top_n, + Scorer *ext_scorer); /* CTC Beam Search Decoder for batch data * Parameters: * probs_seq: 3-D vector that each element is a 2-D vector that can be used * by ctc_beam_search_decoder(). - * vocabulary: A vector of vocabulary. + * alphabet: The alphabet. * beam_size: The width of beam search. * num_processes: Number of threads for beam search. * cutoff_prob: Cutoff probability for pruning. @@ -51,14 +52,19 @@ std::vector> ctc_beam_search_decoder( * A 2-D vector that each element is a vector of beam search decoding * result for one audio sample. */ -std::vector>> +std::vector> ctc_beam_search_decoder_batch( - const std::vector>> &probs_split, - const Alphabet &vocabulary, + const double* probs, + int batch_size, + int time_dim, + int class_dim, + const int* seq_lengths, + int seq_lengths_size, + const Alphabet &alphabet, size_t beam_size, size_t num_processes, - double cutoff_prob = 1.0, - size_t cutoff_top_n = 40, - Scorer *ext_scorer = nullptr); + double cutoff_prob, + size_t cutoff_top_n, + Scorer *ext_scorer); #endif // CTC_BEAM_SEARCH_DECODER_H_ diff --git a/native_client/ctcdecode/decoder_utils.cpp b/native_client/ctcdecode/decoder_utils.cpp index 375a06fe..95779bf5 100644 --- a/native_client/ctcdecode/decoder_utils.cpp +++ b/native_client/ctcdecode/decoder_utils.cpp @@ -5,15 +5,16 @@ #include std::vector> get_pruned_log_probs( - const std::vector &prob_step, + const double *prob_step, + size_t class_dim, double cutoff_prob, size_t cutoff_top_n) { std::vector> prob_idx; - for (size_t i = 0; i < prob_step.size(); ++i) { + for (size_t i = 0; i < class_dim; ++i) { prob_idx.push_back(std::pair(i, prob_step[i])); } // pruning of vacobulary - size_t cutoff_len = prob_step.size(); + size_t cutoff_len = class_dim; if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) { std::sort( prob_idx.begin(), prob_idx.end(), pair_comp_second_rev); @@ -38,7 +39,7 @@ std::vector> get_pruned_log_probs( } -std::vector> get_beam_search_result( +std::vector get_beam_search_result( const std::vector &prefixes, size_t beam_size) { // allow for the post processing @@ -50,17 +51,12 @@ std::vector> get_beam_search_result( } std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare); - std::vector> output_vecs; + std::vector output_vecs; for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) { - std::vector output; - std::vector timesteps; - space_prefixes[i]->get_path_vec(output, timesteps); - Output outputs; - outputs.tokens = output; - outputs.timesteps = timesteps; - std::pair output_pair(-space_prefixes[i]->approx_ctc, - outputs); - output_vecs.emplace_back(output_pair); + Output output; + space_prefixes[i]->get_path_vec(output.tokens, output.timesteps); + output.probability = -space_prefixes[i]->approx_ctc; + output_vecs.emplace_back(output); } return output_vecs; diff --git a/native_client/ctcdecode/decoder_utils.h b/native_client/ctcdecode/decoder_utils.h index 1bc75107..80689fa0 100644 --- a/native_client/ctcdecode/decoder_utils.h +++ b/native_client/ctcdecode/decoder_utils.h @@ -2,6 +2,8 @@ #define DECODER_UTILS_H_ #include +#include + #include "fst/log.h" #include "path_trie.h" #include "output.h" @@ -52,12 +54,13 @@ T log_sum_exp(const T &x, const T &y) { // Get pruned probability vector for each time step's beam search std::vector> get_pruned_log_probs( - const std::vector &prob_step, + const double *prob_step, + size_t class_dim, double cutoff_prob, size_t cutoff_top_n); // Get beam search result from prefixes in trie tree -std::vector> get_beam_search_result( +std::vector get_beam_search_result( const std::vector &prefixes, size_t beam_size); diff --git a/native_client/ctcdecode/output.h b/native_client/ctcdecode/output.h index a921ee2c..efa26518 100644 --- a/native_client/ctcdecode/output.h +++ b/native_client/ctcdecode/output.h @@ -1,11 +1,15 @@ #ifndef OUTPUT_H_ #define OUTPUT_H_ +#include + /* Struct for the beam search output, containing the tokens based on the vocabulary indices, and the timesteps * for each token in the beam search output */ struct Output { - std::vector tokens, timesteps; + double probability; + std::vector tokens; + std::vector timesteps; }; #endif // OUTPUT_H_ diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc index ff6499a7..2ca61b89 100644 --- a/native_client/deepspeech.cc +++ b/native_client/deepspeech.cc @@ -325,19 +325,15 @@ ModelState::decode(vector& logits) const size_t num_classes = alphabet->GetSize() + 1; // +1 for blank const int n_frames = logits.size() / (BATCH_SIZE * num_classes); - vector> inputs; - inputs.resize(n_frames); - for (int t = 0; t < n_frames; ++t) { - for (int i = 0; i < num_classes; ++i) { - inputs[t].push_back(logits[t * num_classes + i]); - } - } + // Convert logits to double + vector inputs(logits.begin(), logits.end()); - // Vector of pairs - vector> out = ctc_beam_search_decoder( - inputs, *alphabet, beam_width, cutoff_prob, cutoff_top_n, scorer); + // Vector of pairs + vector out = ctc_beam_search_decoder( + inputs.data(), n_frames, num_classes, *alphabet, beam_width, + cutoff_prob, cutoff_top_n, scorer); - return strdup(alphabet->LabelsToString(out[0].second.tokens).c_str()); + return strdup(alphabet->LabelsToString(out[0].tokens).c_str()); } int