Simplify ctcdecode API signatures to avoid nested STL structures
Flatten structure to avoid nested STL structures which are awkward to wrap with SWIG and slower at runtime.
This commit is contained in:
parent
bb4551caa9
commit
440893c58d
@ -14,21 +14,19 @@
|
||||
|
||||
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
||||
|
||||
std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
|
||||
const std::vector<std::vector<double>> &probs_seq,
|
||||
std::vector<Output> ctc_beam_search_decoder(
|
||||
const double *probs,
|
||||
int time_dim,
|
||||
int class_dim,
|
||||
const Alphabet &alphabet,
|
||||
size_t beam_size,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
Scorer *ext_scorer) {
|
||||
// dimension check
|
||||
size_t num_time_steps = probs_seq.size();
|
||||
for (size_t i = 0; i < num_time_steps; ++i) {
|
||||
VALID_CHECK_EQ(probs_seq[i].size(),
|
||||
alphabet.GetSize()+1,
|
||||
"The shape of probs_seq does not match with "
|
||||
"the shape of the vocabulary");
|
||||
}
|
||||
VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1,
|
||||
"The shape of probs does not match with "
|
||||
"the shape of the vocabulary");
|
||||
|
||||
// assign special ids
|
||||
int space_id = alphabet.GetSpaceLabel();
|
||||
@ -48,8 +46,8 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
|
||||
}
|
||||
|
||||
// prefix search over time
|
||||
for (size_t time_step = 0; time_step < num_time_steps; ++time_step) {
|
||||
auto &prob = probs_seq[time_step];
|
||||
for (size_t time_step = 0; time_step < time_dim; ++time_step) {
|
||||
auto *prob = &probs[time_step*class_dim];
|
||||
|
||||
float min_cutoff = -NUM_FLT_INF;
|
||||
bool full_beam = false;
|
||||
@ -63,7 +61,7 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
|
||||
}
|
||||
|
||||
std::vector<std::pair<size_t, float>> log_prob_idx =
|
||||
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
|
||||
get_pruned_log_probs(prob, class_dim, cutoff_prob, cutoff_top_n);
|
||||
// loop over chars
|
||||
for (size_t index = 0; index < log_prob_idx.size(); index++) {
|
||||
auto c = log_prob_idx[index].first;
|
||||
@ -178,9 +176,14 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::vector<std::pair<double, Output>>>
|
||||
std::vector<std::vector<Output>>
|
||||
ctc_beam_search_decoder_batch(
|
||||
const std::vector<std::vector<std::vector<double>>> &probs_split,
|
||||
const double *probs,
|
||||
int batch_size,
|
||||
int time_dim,
|
||||
int class_dim,
|
||||
const int* seq_lengths,
|
||||
int seq_lengths_size,
|
||||
const Alphabet &alphabet,
|
||||
size_t beam_size,
|
||||
size_t num_processes,
|
||||
@ -188,16 +191,17 @@ ctc_beam_search_decoder_batch(
|
||||
size_t cutoff_top_n,
|
||||
Scorer *ext_scorer) {
|
||||
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
|
||||
VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per sequence");
|
||||
// thread pool
|
||||
ThreadPool pool(num_processes);
|
||||
// number of samples
|
||||
size_t batch_size = probs_split.size();
|
||||
|
||||
// enqueue the tasks of decoding
|
||||
std::vector<std::future<std::vector<std::pair<double, Output>>>> res;
|
||||
std::vector<std::future<std::vector<Output>>> res;
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
|
||||
probs_split[i],
|
||||
&probs[i*time_dim*class_dim],
|
||||
seq_lengths[i],
|
||||
class_dim,
|
||||
alphabet,
|
||||
beam_size,
|
||||
cutoff_prob,
|
||||
@ -206,7 +210,7 @@ ctc_beam_search_decoder_batch(
|
||||
}
|
||||
|
||||
// get decoding results
|
||||
std::vector<std::vector<std::pair<double, Output>>> batch_results;
|
||||
std::vector<std::vector<Output>> batch_results;
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
batch_results.emplace_back(res[i].get());
|
||||
}
|
||||
|
@ -2,7 +2,6 @@
|
||||
#define CTC_BEAM_SEARCH_DECODER_H_
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "scorer.h"
|
||||
@ -13,8 +12,8 @@
|
||||
|
||||
* Parameters:
|
||||
* probs_seq: 2-D vector in which each element is a vector of probabilities
|
||||
* over vocabulary of one time step.
|
||||
* vocabulary: A vector of vocabulary.
|
||||
* over alphabet of one time step.
|
||||
* alphabet: The alphabet.
|
||||
* beam_size: The width of beam search.
|
||||
* cutoff_prob: Cutoff probability for pruning.
|
||||
* cutoff_top_n: Cutoff number for pruning.
|
||||
@ -26,20 +25,22 @@
|
||||
* in descending order.
|
||||
*/
|
||||
|
||||
std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
|
||||
const std::vector<std::vector<double>> &probs_seq,
|
||||
const Alphabet &vocabulary,
|
||||
std::vector<Output> ctc_beam_search_decoder(
|
||||
const double* probs,
|
||||
int time_dim,
|
||||
int class_dim,
|
||||
const Alphabet &alphabet,
|
||||
size_t beam_size,
|
||||
double cutoff_prob = 1.0,
|
||||
size_t cutoff_top_n = 40,
|
||||
Scorer *ext_scorer = nullptr);
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
Scorer *ext_scorer);
|
||||
|
||||
/* CTC Beam Search Decoder for batch data
|
||||
|
||||
* Parameters:
|
||||
* probs_seq: 3-D vector in which each element is a 2-D vector that can be used
|
||||
* by ctc_beam_search_decoder().
|
||||
* vocabulary: A vector of vocabulary.
|
||||
* alphabet: The alphabet.
|
||||
* beam_size: The width of beam search.
|
||||
* num_processes: Number of threads for beam search.
|
||||
* cutoff_prob: Cutoff probability for pruning.
|
||||
@ -51,14 +52,19 @@ std::vector<std::pair<double, Output>> ctc_beam_search_decoder(
|
||||
* A 2-D vector in which each element is a vector of beam search decoding
|
||||
* result for one audio sample.
|
||||
*/
|
||||
std::vector<std::vector<std::pair<double, Output>>>
|
||||
std::vector<std::vector<Output>>
|
||||
ctc_beam_search_decoder_batch(
|
||||
const std::vector<std::vector<std::vector<double>>> &probs_split,
|
||||
const Alphabet &vocabulary,
|
||||
const double* probs,
|
||||
int batch_size,
|
||||
int time_dim,
|
||||
int class_dim,
|
||||
const int* seq_lengths,
|
||||
int seq_lengths_size,
|
||||
const Alphabet &alphabet,
|
||||
size_t beam_size,
|
||||
size_t num_processes,
|
||||
double cutoff_prob = 1.0,
|
||||
size_t cutoff_top_n = 40,
|
||||
Scorer *ext_scorer = nullptr);
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
Scorer *ext_scorer);
|
||||
|
||||
#endif // CTC_BEAM_SEARCH_DECODER_H_
|
||||
|
@ -5,15 +5,16 @@
|
||||
#include <limits>
|
||||
|
||||
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
|
||||
const std::vector<double> &prob_step,
|
||||
const double *prob_step,
|
||||
size_t class_dim,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n) {
|
||||
std::vector<std::pair<int, double>> prob_idx;
|
||||
for (size_t i = 0; i < prob_step.size(); ++i) {
|
||||
for (size_t i = 0; i < class_dim; ++i) {
|
||||
prob_idx.push_back(std::pair<int, double>(i, prob_step[i]));
|
||||
}
|
||||
// pruning of vocabulary
|
||||
size_t cutoff_len = prob_step.size();
|
||||
size_t cutoff_len = class_dim;
|
||||
if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
|
||||
std::sort(
|
||||
prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
|
||||
@ -38,7 +39,7 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs(
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::pair<double, Output>> get_beam_search_result(
|
||||
std::vector<Output> get_beam_search_result(
|
||||
const std::vector<PathTrie *> &prefixes,
|
||||
size_t beam_size) {
|
||||
// allow for the post processing
|
||||
@ -50,17 +51,12 @@ std::vector<std::pair<double, Output>> get_beam_search_result(
|
||||
}
|
||||
|
||||
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
|
||||
std::vector<std::pair<double, Output>> output_vecs;
|
||||
std::vector<Output> output_vecs;
|
||||
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
|
||||
std::vector<int> output;
|
||||
std::vector<int> timesteps;
|
||||
space_prefixes[i]->get_path_vec(output, timesteps);
|
||||
Output outputs;
|
||||
outputs.tokens = output;
|
||||
outputs.timesteps = timesteps;
|
||||
std::pair<double, Output> output_pair(-space_prefixes[i]->approx_ctc,
|
||||
outputs);
|
||||
output_vecs.emplace_back(output_pair);
|
||||
Output output;
|
||||
space_prefixes[i]->get_path_vec(output.tokens, output.timesteps);
|
||||
output.probability = -space_prefixes[i]->approx_ctc;
|
||||
output_vecs.emplace_back(output);
|
||||
}
|
||||
|
||||
return output_vecs;
|
||||
|
@ -2,6 +2,8 @@
|
||||
#define DECODER_UTILS_H_
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "fst/log.h"
|
||||
#include "path_trie.h"
|
||||
#include "output.h"
|
||||
@ -52,12 +54,13 @@ T log_sum_exp(const T &x, const T &y) {
|
||||
|
||||
// Get pruned probability vector for each time step's beam search
|
||||
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
|
||||
const std::vector<double> &prob_step,
|
||||
const double *prob_step,
|
||||
size_t class_dim,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n);
|
||||
|
||||
// Get beam search result from prefixes in trie tree
|
||||
std::vector<std::pair<double, Output>> get_beam_search_result(
|
||||
std::vector<Output> get_beam_search_result(
|
||||
const std::vector<PathTrie *> &prefixes,
|
||||
size_t beam_size);
|
||||
|
||||
|
@ -1,11 +1,15 @@
|
||||
#ifndef OUTPUT_H_
|
||||
#define OUTPUT_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
/* Struct for the beam search output, containing the tokens based on the vocabulary indices, and the timesteps
|
||||
* for each token in the beam search output
|
||||
*/
|
||||
struct Output {
|
||||
std::vector<int> tokens, timesteps;
|
||||
double probability;
|
||||
std::vector<int> tokens;
|
||||
std::vector<int> timesteps;
|
||||
};
|
||||
|
||||
#endif // OUTPUT_H_
|
||||
|
@ -325,19 +325,15 @@ ModelState::decode(vector<float>& logits)
|
||||
const size_t num_classes = alphabet->GetSize() + 1; // +1 for blank
|
||||
const int n_frames = logits.size() / (BATCH_SIZE * num_classes);
|
||||
|
||||
vector<vector<double>> inputs;
|
||||
inputs.resize(n_frames);
|
||||
for (int t = 0; t < n_frames; ++t) {
|
||||
for (int i = 0; i < num_classes; ++i) {
|
||||
inputs[t].push_back(logits[t * num_classes + i]);
|
||||
}
|
||||
}
|
||||
// Convert logits to double
|
||||
vector<double> inputs(logits.begin(), logits.end());
|
||||
|
||||
// Vector of <probability, Output(tokens, timings)> pairs
|
||||
vector<std::pair<double, Output>> out = ctc_beam_search_decoder(
|
||||
inputs, *alphabet, beam_width, cutoff_prob, cutoff_top_n, scorer);
|
||||
// Vector of Output structs (each holding probability, tokens, timesteps)
|
||||
vector<Output> out = ctc_beam_search_decoder(
|
||||
inputs.data(), n_frames, num_classes, *alphabet, beam_width,
|
||||
cutoff_prob, cutoff_top_n, scorer);
|
||||
|
||||
return strdup(alphabet->LabelsToString(out[0].second.tokens).c_str());
|
||||
return strdup(alphabet->LabelsToString(out[0].tokens).c_str());
|
||||
}
|
||||
|
||||
int
|
||||
|
Loading…
Reference in New Issue
Block a user