Simplify ctcdecode API signatures to avoid nested STL structures

Flatten structure to avoid nested STL structures which are awkward to
wrap with SWIG and slower at runtime.
This commit is contained in:
Reuben Morais 2018-11-02 13:43:16 -03:00
parent bb4551caa9
commit 440893c58d
6 changed files with 72 additions and 63 deletions

View File

@ -14,21 +14,19 @@
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>; using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
std::vector<std::pair<double, Output>> ctc_beam_search_decoder( std::vector<Output> ctc_beam_search_decoder(
const std::vector<std::vector<double>> &probs_seq, const double *probs,
int time_dim,
int class_dim,
const Alphabet &alphabet, const Alphabet &alphabet,
size_t beam_size, size_t beam_size,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
Scorer *ext_scorer) { Scorer *ext_scorer) {
// dimension check // dimension check
size_t num_time_steps = probs_seq.size(); VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1,
for (size_t i = 0; i < num_time_steps; ++i) { "The shape of probs does not match with "
VALID_CHECK_EQ(probs_seq[i].size(), "the shape of the vocabulary");
alphabet.GetSize()+1,
"The shape of probs_seq does not match with "
"the shape of the vocabulary");
}
// assign special ids // assign special ids
int space_id = alphabet.GetSpaceLabel(); int space_id = alphabet.GetSpaceLabel();
@ -48,8 +46,8 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
} }
// prefix search over time // prefix search over time
for (size_t time_step = 0; time_step < num_time_steps; ++time_step) { for (size_t time_step = 0; time_step < time_dim; ++time_step) {
auto &prob = probs_seq[time_step]; auto *prob = &probs[time_step*class_dim];
float min_cutoff = -NUM_FLT_INF; float min_cutoff = -NUM_FLT_INF;
bool full_beam = false; bool full_beam = false;
@ -63,7 +61,7 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
} }
std::vector<std::pair<size_t, float>> log_prob_idx = std::vector<std::pair<size_t, float>> log_prob_idx =
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n); get_pruned_log_probs(prob, class_dim, cutoff_prob, cutoff_top_n);
// loop over chars // loop over chars
for (size_t index = 0; index < log_prob_idx.size(); index++) { for (size_t index = 0; index < log_prob_idx.size(); index++) {
auto c = log_prob_idx[index].first; auto c = log_prob_idx[index].first;
@ -178,9 +176,14 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
} }
std::vector<std::vector<std::pair<double, Output>>> std::vector<std::vector<Output>>
ctc_beam_search_decoder_batch( ctc_beam_search_decoder_batch(
const std::vector<std::vector<std::vector<double>>> &probs_split, const double *probs,
int batch_size,
int time_dim,
int class_dim,
const int* seq_lengths,
int seq_lengths_size,
const Alphabet &alphabet, const Alphabet &alphabet,
size_t beam_size, size_t beam_size,
size_t num_processes, size_t num_processes,
@ -188,16 +191,17 @@ ctc_beam_search_decoder_batch(
size_t cutoff_top_n, size_t cutoff_top_n,
Scorer *ext_scorer) { Scorer *ext_scorer) {
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per sequence");
// thread pool // thread pool
ThreadPool pool(num_processes); ThreadPool pool(num_processes);
// number of samples
size_t batch_size = probs_split.size();
// enqueue the tasks of decoding // enqueue the tasks of decoding
std::vector<std::future<std::vector<std::pair<double, Output>>>> res; std::vector<std::future<std::vector<Output>>> res;
for (size_t i = 0; i < batch_size; ++i) { for (size_t i = 0; i < batch_size; ++i) {
res.emplace_back(pool.enqueue(ctc_beam_search_decoder, res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
probs_split[i], &probs[i*time_dim*class_dim],
seq_lengths[i],
class_dim,
alphabet, alphabet,
beam_size, beam_size,
cutoff_prob, cutoff_prob,
@ -206,7 +210,7 @@ ctc_beam_search_decoder_batch(
} }
// get decoding results // get decoding results
std::vector<std::vector<std::pair<double, Output>>> batch_results; std::vector<std::vector<Output>> batch_results;
for (size_t i = 0; i < batch_size; ++i) { for (size_t i = 0; i < batch_size; ++i) {
batch_results.emplace_back(res[i].get()); batch_results.emplace_back(res[i].get());
} }

View File

@ -2,7 +2,6 @@
#define CTC_BEAM_SEARCH_DECODER_H_ #define CTC_BEAM_SEARCH_DECODER_H_
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
#include "scorer.h" #include "scorer.h"
@ -13,8 +12,8 @@
* Parameters: * Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities * probs_seq: 2-D vector that each element is a vector of probabilities
* over vocabulary of one time step. * over alphabet of one time step.
* vocabulary: A vector of vocabulary. * alphabet: The alphabet.
* beam_size: The width of beam search. * beam_size: The width of beam search.
* cutoff_prob: Cutoff probability for pruning. * cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning. * cutoff_top_n: Cutoff number for pruning.
@ -26,20 +25,22 @@
* in desending order. * in desending order.
*/ */
std::vector<std::pair<double, Output>> ctc_beam_search_decoder( std::vector<Output> ctc_beam_search_decoder(
const std::vector<std::vector<double>> &probs_seq, const double* probs,
const Alphabet &vocabulary, int time_dim,
int class_dim,
const Alphabet &alphabet,
size_t beam_size, size_t beam_size,
double cutoff_prob = 1.0, double cutoff_prob,
size_t cutoff_top_n = 40, size_t cutoff_top_n,
Scorer *ext_scorer = nullptr); Scorer *ext_scorer);
/* CTC Beam Search Decoder for batch data /* CTC Beam Search Decoder for batch data
* Parameters: * Parameters:
* probs_seq: 3-D vector that each element is a 2-D vector that can be used * probs_seq: 3-D vector that each element is a 2-D vector that can be used
* by ctc_beam_search_decoder(). * by ctc_beam_search_decoder().
* vocabulary: A vector of vocabulary. * alphabet: The alphabet.
* beam_size: The width of beam search. * beam_size: The width of beam search.
* num_processes: Number of threads for beam search. * num_processes: Number of threads for beam search.
* cutoff_prob: Cutoff probability for pruning. * cutoff_prob: Cutoff probability for pruning.
@ -51,14 +52,19 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
* A 2-D vector that each element is a vector of beam search decoding * A 2-D vector that each element is a vector of beam search decoding
* result for one audio sample. * result for one audio sample.
*/ */
std::vector<std::vector<std::pair<double, Output>>> std::vector<std::vector<Output>>
ctc_beam_search_decoder_batch( ctc_beam_search_decoder_batch(
const std::vector<std::vector<std::vector<double>>> &probs_split, const double* probs,
const Alphabet &vocabulary, int batch_size,
int time_dim,
int class_dim,
const int* seq_lengths,
int seq_lengths_size,
const Alphabet &alphabet,
size_t beam_size, size_t beam_size,
size_t num_processes, size_t num_processes,
double cutoff_prob = 1.0, double cutoff_prob,
size_t cutoff_top_n = 40, size_t cutoff_top_n,
Scorer *ext_scorer = nullptr); Scorer *ext_scorer);
#endif // CTC_BEAM_SEARCH_DECODER_H_ #endif // CTC_BEAM_SEARCH_DECODER_H_

View File

@ -5,15 +5,16 @@
#include <limits> #include <limits>
std::vector<std::pair<size_t, float>> get_pruned_log_probs( std::vector<std::pair<size_t, float>> get_pruned_log_probs(
const std::vector<double> &prob_step, const double *prob_step,
size_t class_dim,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n) { size_t cutoff_top_n) {
std::vector<std::pair<int, double>> prob_idx; std::vector<std::pair<int, double>> prob_idx;
for (size_t i = 0; i < prob_step.size(); ++i) { for (size_t i = 0; i < class_dim; ++i) {
prob_idx.push_back(std::pair<int, double>(i, prob_step[i])); prob_idx.push_back(std::pair<int, double>(i, prob_step[i]));
} }
// pruning of vacobulary // pruning of vacobulary
size_t cutoff_len = prob_step.size(); size_t cutoff_len = class_dim;
if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) { if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
std::sort( std::sort(
prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>); prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
@ -38,7 +39,7 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs(
} }
std::vector<std::pair<double, Output>> get_beam_search_result( std::vector<Output> get_beam_search_result(
const std::vector<PathTrie *> &prefixes, const std::vector<PathTrie *> &prefixes,
size_t beam_size) { size_t beam_size) {
// allow for the post processing // allow for the post processing
@ -50,17 +51,12 @@ std::vector<std::pair<double, Output>> get_beam_search_result(
} }
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare); std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
std::vector<std::pair<double, Output>> output_vecs; std::vector<Output> output_vecs;
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) { for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
std::vector<int> output; Output output;
std::vector<int> timesteps; space_prefixes[i]->get_path_vec(output.tokens, output.timesteps);
space_prefixes[i]->get_path_vec(output, timesteps); output.probability = -space_prefixes[i]->approx_ctc;
Output outputs; output_vecs.emplace_back(output);
outputs.tokens = output;
outputs.timesteps = timesteps;
std::pair<double, Output> output_pair(-space_prefixes[i]->approx_ctc,
outputs);
output_vecs.emplace_back(output_pair);
} }
return output_vecs; return output_vecs;

View File

@ -2,6 +2,8 @@
#define DECODER_UTILS_H_ #define DECODER_UTILS_H_
#include <utility> #include <utility>
#include <vector>
#include "fst/log.h" #include "fst/log.h"
#include "path_trie.h" #include "path_trie.h"
#include "output.h" #include "output.h"
@ -52,12 +54,13 @@ T log_sum_exp(const T &x, const T &y) {
// Get pruned probability vector for each time step's beam search // Get pruned probability vector for each time step's beam search
std::vector<std::pair<size_t, float>> get_pruned_log_probs( std::vector<std::pair<size_t, float>> get_pruned_log_probs(
const std::vector<double> &prob_step, const double *prob_step,
size_t class_dim,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n); size_t cutoff_top_n);
// Get beam search result from prefixes in trie tree // Get beam search result from prefixes in trie tree
std::vector<std::pair<double, Output>> get_beam_search_result( std::vector<Output> get_beam_search_result(
const std::vector<PathTrie *> &prefixes, const std::vector<PathTrie *> &prefixes,
size_t beam_size); size_t beam_size);

View File

@ -1,11 +1,15 @@
#ifndef OUTPUT_H_ #ifndef OUTPUT_H_
#define OUTPUT_H_ #define OUTPUT_H_
#include <vector>
/* Struct for the beam search output, containing the tokens based on the vocabulary indices, and the timesteps /* Struct for the beam search output, containing the tokens based on the vocabulary indices, and the timesteps
* for each token in the beam search output * for each token in the beam search output
*/ */
struct Output { struct Output {
std::vector<int> tokens, timesteps; double probability;
std::vector<int> tokens;
std::vector<int> timesteps;
}; };
#endif // OUTPUT_H_ #endif // OUTPUT_H_

View File

@ -325,19 +325,15 @@ ModelState::decode(vector<float>& logits)
const size_t num_classes = alphabet->GetSize() + 1; // +1 for blank const size_t num_classes = alphabet->GetSize() + 1; // +1 for blank
const int n_frames = logits.size() / (BATCH_SIZE * num_classes); const int n_frames = logits.size() / (BATCH_SIZE * num_classes);
vector<vector<double>> inputs; // Convert logits to double
inputs.resize(n_frames); vector<double> inputs(logits.begin(), logits.end());
for (int t = 0; t < n_frames; ++t) {
for (int i = 0; i < num_classes; ++i) {
inputs[t].push_back(logits[t * num_classes + i]);
}
}
// Vector of <probability, Output(tokens, timings)> pairs // Vector of <probability, Output> pairs
vector<std::pair<double, Output>> out = ctc_beam_search_decoder( vector<Output> out = ctc_beam_search_decoder(
inputs, *alphabet, beam_width, cutoff_prob, cutoff_top_n, scorer); inputs.data(), n_frames, num_classes, *alphabet, beam_width,
cutoff_prob, cutoff_top_n, scorer);
return strdup(alphabet->LabelsToString(out[0].second.tokens).c_str()); return strdup(alphabet->LabelsToString(out[0].tokens).c_str());
} }
int int