Merge pull request #2121 from dabinat/streaming-decoder

CTC streaming decoder
This commit is contained in:
dabinat 2019-05-21 21:41:48 -07:00 committed by GitHub
commit 69538f2f62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 220 additions and 91 deletions

View File

@ -14,49 +14,61 @@
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
std::vector<Output> ctc_beam_search_decoder(
const double *probs,
int time_dim,
int class_dim,
const Alphabet &alphabet,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer) {
DecoderState* decoder_init(const Alphabet &alphabet,
int class_dim,
Scorer* ext_scorer) {
// dimension check
VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1,
"The shape of probs does not match with "
"the shape of the vocabulary");
// assign special ids
int space_id = alphabet.GetSpaceLabel();
int blank_id = alphabet.GetSize();
DecoderState *state = new DecoderState;
state->space_id = alphabet.GetSpaceLabel();
state->blank_id = alphabet.GetSize();
// init prefixes' root
PathTrie root;
root.score = root.log_prob_b_prev = 0.0;
std::vector<PathTrie *> prefixes;
prefixes.push_back(&root);
PathTrie *root = new PathTrie;
root->score = root->log_prob_b_prev = 0.0;
state->prefix_root = root;
state->prefixes.push_back(root);
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
auto dict_ptr = ext_scorer->dictionary->Copy(true);
root.set_dictionary(dict_ptr);
root->set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root.set_matcher(matcher);
root->set_matcher(matcher);
}
return state;
}
// prefix search over time
void decoder_next(const double *probs,
const Alphabet &alphabet,
DecoderState *state,
int time_dim,
int class_dim,
double cutoff_prob,
size_t cutoff_top_n,
size_t beam_size,
Scorer *ext_scorer) {
// prefix search over time
for (size_t time_step = 0; time_step < time_dim; ++time_step) {
auto *prob = &probs[time_step*class_dim];
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (ext_scorer != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta);
state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);
min_cutoff = state->prefixes[num_prefixes - 1]->score +
std::log(prob[state->blank_id]) - std::max(0.0, ext_scorer->beta);
full_beam = (num_prefixes == beam_size);
}
@ -67,22 +79,25 @@ std::vector<Output> ctc_beam_search_decoder(
auto c = log_prob_idx[index].first;
auto log_prob_c = log_prob_idx[index].second;
for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
auto prefix = prefixes[i];
for (size_t i = 0; i < state->prefixes.size() && i < beam_size; ++i) {
auto prefix = state->prefixes[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
// blank
if (c == blank_id) {
if (c == state->blank_id) {
prefix->log_prob_b_cur =
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue;
}
// repeated character
if (c == prefix->character) {
prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c, time_step, log_prob_c);
@ -98,7 +113,7 @@ std::vector<Output> ctc_beam_search_decoder(
// language model scoring
if (ext_scorer != nullptr &&
(c == space_id || ext_scorer->is_character_based())) {
(c == state->space_id || ext_scorer->is_character_based())) {
PathTrie *prefix_to_score = nullptr;
// skip scoring the space
if (ext_scorer->is_character_based()) {
@ -114,34 +129,41 @@ std::vector<Output> ctc_beam_search_decoder(
log_p += score;
log_p += ext_scorer->beta;
}
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
} // end of loop over prefix
} // end of loop over vocabulary
prefixes.clear();
// update log probs
root.iterate_to_vec(prefixes);
state->prefixes.clear();
state->prefix_root->iterate_to_vec(state->prefixes);
// only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
prefixes.end(),
if (state->prefixes.size() >= beam_size) {
std::nth_element(state->prefixes.begin(),
state->prefixes.begin() + beam_size,
state->prefixes.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove();
for (size_t i = beam_size; i < state->prefixes.size(); ++i) {
state->prefixes[i]->remove();
}
}
} // end of loop over time
}
std::vector<Output> decoder_decode(DecoderState *state,
const Alphabet &alphabet,
size_t beam_size,
Scorer* ext_scorer) {
// score the last word of each prefix that doesn't end with space
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i];
if (!prefix->is_empty() && prefix->character != space_id) {
for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) {
auto prefix = state->prefixes[i];
if (!prefix->is_empty() && prefix->character != state->space_id) {
float score = 0.0;
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
@ -151,17 +173,17 @@ std::vector<Output> ctc_beam_search_decoder(
}
}
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
std::sort(state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);
// compute approximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
double approx_ctc = prefixes[i]->score;
for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) {
double approx_ctc = state->prefixes[i]->score;
if (ext_scorer != nullptr) {
std::vector<int> output;
std::vector<int> timesteps;
prefixes[i]->get_path_vec(output, timesteps);
state->prefixes[i]->get_path_vec(output, timesteps);
auto prefix_length = output.size();
auto words = ext_scorer->split_labels(output);
// remove word insert
@ -169,12 +191,30 @@ std::vector<Output> ctc_beam_search_decoder(
// remove language model weight:
approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
}
prefixes[i]->approx_ctc = approx_ctc;
state->prefixes[i]->approx_ctc = approx_ctc;
}
return get_beam_search_result(prefixes, beam_size);
return get_beam_search_result(state->prefixes, beam_size);
}
std::vector<Output> ctc_beam_search_decoder(
const double *probs,
int time_dim,
int class_dim,
const Alphabet &alphabet,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer) {
DecoderState *state = decoder_init(alphabet, class_dim, ext_scorer);
decoder_next(probs, alphabet, state, time_dim, class_dim, cutoff_prob, cutoff_top_n, beam_size, ext_scorer);
std::vector<Output> out = decoder_decode(state, alphabet, beam_size, ext_scorer);
delete state;
return out;
}
std::vector<std::vector<Output>>
ctc_beam_search_decoder_batch(

View File

@ -7,12 +7,73 @@
#include "scorer.h"
#include "output.h"
#include "alphabet.h"
#include "decoderstate.h"
/* CTC Beam Search Decoder
/* Initialize CTC beam search decoder
* Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities
* over alphabet of one time step.
* alphabet: The alphabet.
* class_dim: Alphabet length (plus 1 for space character).
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* A struct containing prefixes and state variables.
*/
DecoderState* decoder_init(const Alphabet &alphabet,
int class_dim,
Scorer *ext_scorer);
/* Send data to the decoder
* Parameters:
* probs: 2-D vector where each element is a vector of probabilities
* over alphabet of one time step.
* alphabet: The alphabet.
* state: The state structure previously obtained from decoder_init().
* time_dim: Number of timesteps.
* class_dim: Alphabet length (plus 1 for space character).
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
* beam_size: The width of beam search.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
*/
void decoder_next(const double *probs,
const Alphabet &alphabet,
DecoderState *state,
int time_dim,
int class_dim,
double cutoff_prob,
size_t cutoff_top_n,
size_t beam_size,
Scorer *ext_scorer);
/* Get transcription for the data you sent via decoder_next()
* Parameters:
* state: The state structure previously obtained from decoder_init().
* alphabet: The alphabet.
* beam_size: The width of beam search.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* A vector where each element is a pair of score and decoding result,
* in descending order.
*/
std::vector<Output> decoder_decode(DecoderState *state,
const Alphabet &alphabet,
size_t beam_size,
Scorer* ext_scorer);
/* CTC Beam Search Decoder
* Parameters:
* probs: 2-D vector where each element is a vector of probabilities
* over alphabet of one time step.
* time_dim: Number of timesteps.
* class_dim: Alphabet length (plus 1 for space character).
* alphabet: The alphabet.
* beam_size: The width of beam search.
* cutoff_prob: Cutoff probability for pruning.
@ -21,8 +82,8 @@
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* A vector that each element is a pair of score and decoding result,
* in desending order.
* A vector where each element is a pair of score and decoding result,
* in descending order.
*/
std::vector<Output> ctc_beam_search_decoder(
@ -36,9 +97,8 @@ std::vector<Output> ctc_beam_search_decoder(
Scorer *ext_scorer);
/* CTC Beam Search Decoder for batch data
* Parameters:
* probs_seq: 3-D vector that each element is a 2-D vector that can be used
* probs: 3-D vector where each element is a 2-D vector that can be used
* by ctc_beam_search_decoder().
* alphabet: The alphabet.
* beam_size: The width of beam search.
@ -49,7 +109,7 @@ std::vector<Output> ctc_beam_search_decoder(
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* A 2-D vector that each element is a vector of beam search decoding
* A 2-D vector where each element is a vector of beam search decoding
* result for one audio sample.
*/
std::vector<std::vector<Output>>

View File

@ -0,0 +1,22 @@
#ifndef DECODERSTATE_H_
#define DECODERSTATE_H_
#include <vector>
/* Struct for the state of the decoder, containing the prefixes and initial root prefix plus state variables. */
struct DecoderState {
int space_id;
int blank_id;
std::vector<PathTrie*> prefixes;
PathTrie *prefix_root;
~DecoderState() {
if (prefix_root != nullptr) {
delete prefix_root;
}
prefix_root = nullptr;
}
};
#endif // DECODERSTATE_H_

View File

@ -75,13 +75,12 @@ using std::vector;
API. When audio_buffer is full, features are computed from it and pushed to
mfcc_buffer. When mfcc_buffer is full, the timestep is copied to batch_buffer.
When batch_buffer is full, we do a single step through the acoustic model
and accumulate results in StreamingState::accumulated_logits.
and accumulate results in the DecoderState structure.
When fininshStream() is called, we decode the accumulated logits and return
When finishStream() is called, we decode the accumulated logits and return
the corresponding transcription.
*/
struct StreamingState {
vector<float> accumulated_logits;
vector<float> audio_buffer;
vector<float> mfcc_buffer;
vector<float> batch_buffer;
@ -113,6 +112,7 @@ struct ModelState {
unsigned int ncontext;
Alphabet* alphabet;
Scorer* scorer;
DecoderState* decoder_state;
unsigned int beam_width;
unsigned int n_steps;
unsigned int n_context;
@ -145,34 +145,26 @@ struct ModelState {
* @brief Perform decoding of the logits, using basic CTC decoder or
* CTC decoder with KenLM enabled
*
* @param logits Flat matrix of logits, of size:
* n_frames * batch_size * num_classes
*
* @return String representing the decoded text.
*/
char* decode(const vector<float>& logits);
char* decode();
/**
* @brief Perform decoding of the logits, using basic CTC decoder or
* CTC decoder with KenLM enabled
*
* @param logits Flat matrix of logits, of size:
* n_frames * batch_size * num_classes
*
* @return Vector of Output structs directly from the CTC decoder for additional processing.
*/
vector<Output> decode_raw(const vector<float>& logits);
vector<Output> decode_raw();
/**
* @brief Return character-level metadata including letter timings.
*
* @param logits Flat matrix of logits, of size:
* n_frames * batch_size * num_classes
*
* @return Metadata struct containing MetadataItem structs for each character.
* The user is responsible for freeing Metadata by calling DS_FreeMetadata().
*/
Metadata* decode_metadata(const vector<float>& logits);
Metadata* decode_metadata();
/**
* @brief Do a single inference step in the acoustic model, with:
@ -202,6 +194,7 @@ ModelState::ModelState()
, ncontext(0)
, alphabet(nullptr)
, scorer(nullptr)
, decoder_state(nullptr)
, beam_width(0)
, n_steps(-1)
, n_context(-1)
@ -232,6 +225,11 @@ ModelState::~ModelState()
delete scorer;
delete alphabet;
if (decoder_state != nullptr) {
delete decoder_state;
decoder_state = nullptr;
}
}
template<typename T>
@ -270,21 +268,21 @@ StreamingState::feedAudioContent(const short* buffer,
char*
StreamingState::intermediateDecode()
{
return model->decode(accumulated_logits);
return model->decode();
}
char*
StreamingState::finishStream()
{
finalizeStream();
return model->decode(accumulated_logits);
return model->decode();
}
Metadata*
StreamingState::finishStreamWithMetadata()
{
finalizeStream();
return model->decode_metadata(accumulated_logits);
return model->decode_metadata();
}
void
@ -372,7 +370,26 @@ StreamingState::processMfccWindow(const vector<float>& buf)
void
StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
{
model->infer(buf.data(), n_steps, accumulated_logits);
vector<float> logits;
model->infer(buf.data(), n_steps, logits);
const int cutoff_top_n = 40;
const double cutoff_prob = 1.0;
const size_t num_classes = model->alphabet->GetSize() + 1; // +1 for blank
const int n_frames = logits.size() / (BATCH_SIZE * num_classes);
// Convert logits to double
vector<double> inputs(logits.begin(), logits.end());
decoder_next(inputs.data(),
*model->alphabet,
model->decoder_state,
n_frames,
num_classes,
cutoff_prob,
cutoff_top_n,
model->beam_width,
model->scorer);
}
void
@ -507,35 +524,24 @@ ModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_outpu
}
char*
ModelState::decode(const vector<float>& logits)
ModelState::decode()
{
vector<Output> out = ModelState::decode_raw(logits);
vector<Output> out = ModelState::decode_raw();
return strdup(alphabet->LabelsToString(out[0].tokens).c_str());
}
vector<Output>
ModelState::decode_raw(const vector<float>& logits)
ModelState::decode_raw()
{
const int cutoff_top_n = 40;
const double cutoff_prob = 1.0;
const size_t num_classes = alphabet->GetSize() + 1; // +1 for blank
const int n_frames = logits.size() / (BATCH_SIZE * num_classes);
// Convert logits to double
vector<double> inputs(logits.begin(), logits.end());
// Vector of <probability, Output> pairs
vector<Output> out = ctc_beam_search_decoder(
inputs.data(), n_frames, num_classes, *alphabet, beam_width,
cutoff_prob, cutoff_top_n, scorer);
vector<Output> out = decoder_decode(decoder_state, *alphabet, beam_width, scorer);
return out;
}
Metadata*
ModelState::decode_metadata(const vector<float>& logits)
ModelState::decode_metadata()
{
vector<Output> out = decode_raw(logits);
vector<Output> out = decode_raw();
std::unique_ptr<Metadata> metadata(new Metadata());
metadata->num_items = out[0].tokens.size();
@ -830,8 +836,6 @@ DS_SetupStream(ModelState* aCtx,
aPreAllocFrames = 150;
}
ctx->accumulated_logits.reserve(aPreAllocFrames * BATCH_SIZE * num_classes);
ctx->audio_buffer.reserve(aCtx->audio_win_len);
ctx->mfcc_buffer.reserve(aCtx->mfcc_feats_per_timestep);
ctx->mfcc_buffer.resize(aCtx->n_features*aCtx->n_context, 0.f);
@ -839,6 +843,9 @@ DS_SetupStream(ModelState* aCtx,
ctx->model = aCtx;
DecoderState *params = decoder_init(*aCtx->alphabet, num_classes, aCtx->scorer);
aCtx->decoder_state = params;
*retval = ctx.release();
return DS_ERR_OK;
}