Simplify ctcdecode API signatures to avoid nested STL structures

Flatten the decoder inputs into raw arrays plus explicit dimensions, and fold
the probability into the Output struct, because nested STL containers are
awkward to wrap with SWIG and slower at runtime.
This commit is contained in:
Reuben Morais 2018-11-02 13:43:16 -03:00
parent bb4551caa9
commit 440893c58d
6 changed files with 72 additions and 63 deletions

View File

@ -14,21 +14,19 @@
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
const std::vector<std::vector<double>> &probs_seq,
std::vector<Output> ctc_beam_search_decoder(
const double *probs,
int time_dim,
int class_dim,
const Alphabet &alphabet,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer) {
// dimension check
size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) {
VALID_CHECK_EQ(probs_seq[i].size(),
alphabet.GetSize()+1,
"The shape of probs_seq does not match with "
"the shape of the vocabulary");
}
VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1,
"The shape of probs does not match with "
"the shape of the vocabulary");
// assign special ids
int space_id = alphabet.GetSpaceLabel();
@ -48,8 +46,8 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
}
// prefix search over time
for (size_t time_step = 0; time_step < num_time_steps; ++time_step) {
auto &prob = probs_seq[time_step];
for (size_t time_step = 0; time_step < time_dim; ++time_step) {
auto *prob = &probs[time_step*class_dim];
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
@ -63,7 +61,7 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
}
std::vector<std::pair<size_t, float>> log_prob_idx =
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
get_pruned_log_probs(prob, class_dim, cutoff_prob, cutoff_top_n);
// loop over chars
for (size_t index = 0; index < log_prob_idx.size(); index++) {
auto c = log_prob_idx[index].first;
@ -178,9 +176,14 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
}
std::vector<std::vector<std::pair<double, Output>>>
std::vector<std::vector<Output>>
ctc_beam_search_decoder_batch(
const std::vector<std::vector<std::vector<double>>> &probs_split,
const double *probs,
int batch_size,
int time_dim,
int class_dim,
const int* seq_lengths,
int seq_lengths_size,
const Alphabet &alphabet,
size_t beam_size,
size_t num_processes,
@ -188,16 +191,17 @@ ctc_beam_search_decoder_batch(
size_t cutoff_top_n,
Scorer *ext_scorer) {
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per sequence");
// thread pool
ThreadPool pool(num_processes);
// number of samples
size_t batch_size = probs_split.size();
// enqueue the tasks of decoding
std::vector<std::future<std::vector<std::pair<double, Output>>>> res;
std::vector<std::future<std::vector<Output>>> res;
for (size_t i = 0; i < batch_size; ++i) {
res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
probs_split[i],
&probs[i*time_dim*class_dim],
seq_lengths[i],
class_dim,
alphabet,
beam_size,
cutoff_prob,
@ -206,7 +210,7 @@ ctc_beam_search_decoder_batch(
}
// get decoding results
std::vector<std::vector<std::pair<double, Output>>> batch_results;
std::vector<std::vector<Output>> batch_results;
for (size_t i = 0; i < batch_size; ++i) {
batch_results.emplace_back(res[i].get());
}

View File

@ -2,7 +2,6 @@
#define CTC_BEAM_SEARCH_DECODER_H_
#include <string>
#include <utility>
#include <vector>
#include "scorer.h"
@ -13,8 +12,8 @@
* Parameters:
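* time_dim: Number of time steps (rows) in probs.
* class_dim: Number of classes per time step; must equal alphabet size + 1.
* probs: Row-major time_dim x class_dim array; each row holds the probabilities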
* over vocabulary of one time step.
* vocabulary: A vector of vocabulary.
* over alphabet of one time step.
* alphabet: The alphabet.
* beam_size: The width of beam search.
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
@ -26,20 +25,22 @@
* in descending order.
*/
std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
const std::vector<std::vector<double>> &probs_seq,
const Alphabet &vocabulary,
std::vector<Output> ctc_beam_search_decoder(
const double* probs,
int time_dim,
int class_dim,
const Alphabet &alphabet,
size_t beam_size,
double cutoff_prob = 1.0,
size_t cutoff_top_n = 40,
Scorer *ext_scorer = nullptr);
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer);
/* CTC Beam Search Decoder for batch data
* Parameters:
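* probs: Flat batch_size x time_dim x class_dim array; sample i starts at
* probs[i * time_dim * class_dim] and is decoded by
* ctc_beam_search_decoder().
* batch_size: Number of samples in probs.
* time_dim: Number of time steps allotted to each sample in probs.
* class_dim: Number of classes per time step.
* seq_lengths: Number of time steps to decode for each sample.
* seq_lengths_size: Length of seq_lengths; must equal batch_size.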
* vocabulary: A vector of vocabulary.
* alphabet: The alphabet.
* beam_size: The width of beam search.
* num_processes: Number of threads for beam search.
* cutoff_prob: Cutoff probability for pruning.
@ -51,14 +52,19 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
* A 2-D vector in which each element is the vector of beam search decoding
* results for one audio sample.
*/
std::vector<std::vector<std::pair<double, Output>>>
std::vector<std::vector<Output>>
ctc_beam_search_decoder_batch(
const std::vector<std::vector<std::vector<double>>> &probs_split,
const Alphabet &vocabulary,
const double* probs,
int batch_size,
int time_dim,
int class_dim,
const int* seq_lengths,
int seq_lengths_size,
const Alphabet &alphabet,
size_t beam_size,
size_t num_processes,
double cutoff_prob = 1.0,
size_t cutoff_top_n = 40,
Scorer *ext_scorer = nullptr);
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer);
#endif // CTC_BEAM_SEARCH_DECODER_H_

View File

@ -5,15 +5,16 @@
#include <limits>
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
const std::vector<double> &prob_step,
const double *prob_step,
size_t class_dim,
double cutoff_prob,
size_t cutoff_top_n) {
std::vector<std::pair<int, double>> prob_idx;
for (size_t i = 0; i < prob_step.size(); ++i) {
for (size_t i = 0; i < class_dim; ++i) {
prob_idx.push_back(std::pair<int, double>(i, prob_step[i]));
}
// pruning of vocabulary
size_t cutoff_len = prob_step.size();
size_t cutoff_len = class_dim;
if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
std::sort(
prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
@ -38,7 +39,7 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs(
}
std::vector<std::pair<double, Output>> get_beam_search_result(
std::vector<Output> get_beam_search_result(
const std::vector<PathTrie *> &prefixes,
size_t beam_size) {
// allow for the post processing
@ -50,17 +51,12 @@ std::vector<std::pair<double, Output>> get_beam_search_result(
}
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
std::vector<std::pair<double, Output>> output_vecs;
std::vector<Output> output_vecs;
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
std::vector<int> output;
std::vector<int> timesteps;
space_prefixes[i]->get_path_vec(output, timesteps);
Output outputs;
outputs.tokens = output;
outputs.timesteps = timesteps;
std::pair<double, Output> output_pair(-space_prefixes[i]->approx_ctc,
outputs);
output_vecs.emplace_back(output_pair);
Output output;
space_prefixes[i]->get_path_vec(output.tokens, output.timesteps);
output.probability = -space_prefixes[i]->approx_ctc;
output_vecs.emplace_back(output);
}
return output_vecs;

View File

@ -2,6 +2,8 @@
#define DECODER_UTILS_H_
#include <utility>
#include <vector>
#include "fst/log.h"
#include "path_trie.h"
#include "output.h"
@ -52,12 +54,13 @@ T log_sum_exp(const T &x, const T &y) {
// Get pruned probability vector for each time step's beam search
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
const std::vector<double> &prob_step,
const double *prob_step,
size_t class_dim,
double cutoff_prob,
size_t cutoff_top_n);
// Get beam search result from prefixes in trie tree
std::vector<std::pair<double, Output>> get_beam_search_result(
std::vector<Output> get_beam_search_result(
const std::vector<PathTrie *> &prefixes,
size_t beam_size);

View File

@ -1,11 +1,15 @@
#ifndef OUTPUT_H_
#define OUTPUT_H_
#include <vector>
/* Struct for one beam search result, containing the tokens based on the
* vocabulary indices, the timesteps for each token in the beam search output,
* and the score assigned to the result.
*/
struct Output {
std::vector<int> tokens, timesteps;
// Score of this result; get_beam_search_result() sets it to the negated
// approx_ctc of the prefix the result was read from.
double probability;
// Token labels, filled by get_path_vec() in get_beam_search_result().
std::vector<int> tokens;
// Timesteps filled alongside tokens by the same get_path_vec() call.
std::vector<int> timesteps;
};
#endif // OUTPUT_H_

View File

@ -325,19 +325,15 @@ ModelState::decode(vector<float>& logits)
const size_t num_classes = alphabet->GetSize() + 1; // +1 for blank
const int n_frames = logits.size() / (BATCH_SIZE * num_classes);
vector<vector<double>> inputs;
inputs.resize(n_frames);
for (int t = 0; t < n_frames; ++t) {
for (int i = 0; i < num_classes; ++i) {
inputs[t].push_back(logits[t * num_classes + i]);
}
}
// Convert logits to double
vector<double> inputs(logits.begin(), logits.end());
// Vector of <probability, Output(tokens, timings)> pairs
vector<std::pair<double, Output>> out = ctc_beam_search_decoder(
inputs, *alphabet, beam_width, cutoff_prob, cutoff_top_n, scorer);
// Vector of Output structs; each one carries its own probability
vector<Output> out = ctc_beam_search_decoder(
inputs.data(), n_frames, num_classes, *alphabet, beam_width,
cutoff_prob, cutoff_top_n, scorer);
return strdup(alphabet->LabelsToString(out[0].second.tokens).c_str());
return strdup(alphabet->LabelsToString(out[0].tokens).c_str());
}
int