Merge pull request #2121 from dabinat/streaming-decoder

CTC streaming decoder
This commit is contained in:
dabinat 2019-05-21 21:41:48 -07:00 committed by GitHub
commit 69538f2f62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 220 additions and 91 deletions

View File

@ -14,49 +14,61 @@
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>; using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
std::vector<Output> ctc_beam_search_decoder( DecoderState* decoder_init(const Alphabet &alphabet,
const double *probs, int class_dim,
int time_dim, Scorer* ext_scorer) {
int class_dim,
const Alphabet &alphabet,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer) {
// dimension check // dimension check
VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1, VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1,
"The shape of probs does not match with " "The shape of probs does not match with "
"the shape of the vocabulary"); "the shape of the vocabulary");
// assign special ids // assign special ids
int space_id = alphabet.GetSpaceLabel(); DecoderState *state = new DecoderState;
int blank_id = alphabet.GetSize(); state->space_id = alphabet.GetSpaceLabel();
state->blank_id = alphabet.GetSize();
// init prefixes' root // init prefixes' root
PathTrie root; PathTrie *root = new PathTrie;
root.score = root.log_prob_b_prev = 0.0; root->score = root->log_prob_b_prev = 0.0;
std::vector<PathTrie *> prefixes;
prefixes.push_back(&root); state->prefix_root = root;
state->prefixes.push_back(root);
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
auto dict_ptr = ext_scorer->dictionary->Copy(true); auto dict_ptr = ext_scorer->dictionary->Copy(true);
root.set_dictionary(dict_ptr); root->set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT); auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root.set_matcher(matcher); root->set_matcher(matcher);
} }
return state;
}
// prefix search over time void decoder_next(const double *probs,
const Alphabet &alphabet,
DecoderState *state,
int time_dim,
int class_dim,
double cutoff_prob,
size_t cutoff_top_n,
size_t beam_size,
Scorer *ext_scorer) {
// prefix search over time
for (size_t time_step = 0; time_step < time_dim; ++time_step) { for (size_t time_step = 0; time_step < time_dim; ++time_step) {
auto *prob = &probs[time_step*class_dim]; auto *prob = &probs[time_step*class_dim];
float min_cutoff = -NUM_FLT_INF; float min_cutoff = -NUM_FLT_INF;
bool full_beam = false; bool full_beam = false;
if (ext_scorer != nullptr) { if (ext_scorer != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size); size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
std::sort( std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta); min_cutoff = state->prefixes[num_prefixes - 1]->score +
std::log(prob[state->blank_id]) - std::max(0.0, ext_scorer->beta);
full_beam = (num_prefixes == beam_size); full_beam = (num_prefixes == beam_size);
} }
@ -67,22 +79,25 @@ std::vector<Output> ctc_beam_search_decoder(
auto c = log_prob_idx[index].first; auto c = log_prob_idx[index].first;
auto log_prob_c = log_prob_idx[index].second; auto log_prob_c = log_prob_idx[index].second;
for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) { for (size_t i = 0; i < state->prefixes.size() && i < beam_size; ++i) {
auto prefix = prefixes[i]; auto prefix = state->prefixes[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) { if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break; break;
} }
// blank // blank
if (c == blank_id) { if (c == state->blank_id) {
prefix->log_prob_b_cur = prefix->log_prob_b_cur =
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue; continue;
} }
// repeated character // repeated character
if (c == prefix->character) { if (c == prefix->character) {
prefix->log_prob_nb_cur = log_sum_exp( prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev); prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
} }
// get new prefix // get new prefix
auto prefix_new = prefix->get_path_trie(c, time_step, log_prob_c); auto prefix_new = prefix->get_path_trie(c, time_step, log_prob_c);
@ -98,7 +113,7 @@ std::vector<Output> ctc_beam_search_decoder(
// language model scoring // language model scoring
if (ext_scorer != nullptr && if (ext_scorer != nullptr &&
(c == space_id || ext_scorer->is_character_based())) { (c == state->space_id || ext_scorer->is_character_based())) {
PathTrie *prefix_to_score = nullptr; PathTrie *prefix_to_score = nullptr;
// skip scoring the space // skip scoring the space
if (ext_scorer->is_character_based()) { if (ext_scorer->is_character_based()) {
@ -114,34 +129,41 @@ std::vector<Output> ctc_beam_search_decoder(
log_p += score; log_p += score;
log_p += ext_scorer->beta; log_p += ext_scorer->beta;
} }
prefix_new->log_prob_nb_cur = prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p); log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
} }
} // end of loop over prefix } // end of loop over prefix
} // end of loop over vocabulary } // end of loop over vocabulary
prefixes.clear();
// update log probs // update log probs
root.iterate_to_vec(prefixes); state->prefixes.clear();
state->prefix_root->iterate_to_vec(state->prefixes);
// only preserve top beam_size prefixes // only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) { if (state->prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(), std::nth_element(state->prefixes.begin(),
prefixes.begin() + beam_size, state->prefixes.begin() + beam_size,
prefixes.end(), state->prefixes.end(),
prefix_compare); prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) { for (size_t i = beam_size; i < state->prefixes.size(); ++i) {
prefixes[i]->remove(); state->prefixes[i]->remove();
} }
} }
} // end of loop over time } // end of loop over time
}
std::vector<Output> decoder_decode(DecoderState *state,
const Alphabet &alphabet,
size_t beam_size,
Scorer* ext_scorer) {
// score the last word of each prefix that doesn't end with space // score the last word of each prefix that doesn't end with space
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) {
auto prefix = prefixes[i]; auto prefix = state->prefixes[i];
if (!prefix->is_empty() && prefix->character != space_id) { if (!prefix->is_empty() && prefix->character != state->space_id) {
float score = 0.0; float score = 0.0;
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix); std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
@ -151,17 +173,17 @@ std::vector<Output> ctc_beam_search_decoder(
} }
} }
size_t num_prefixes = std::min(prefixes.size(), beam_size); size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); std::sort(state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);
// compute approximate ctc score as the return score, without affecting the // compute approximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable. // return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) {
double approx_ctc = prefixes[i]->score; double approx_ctc = state->prefixes[i]->score;
if (ext_scorer != nullptr) { if (ext_scorer != nullptr) {
std::vector<int> output; std::vector<int> output;
std::vector<int> timesteps; std::vector<int> timesteps;
prefixes[i]->get_path_vec(output, timesteps); state->prefixes[i]->get_path_vec(output, timesteps);
auto prefix_length = output.size(); auto prefix_length = output.size();
auto words = ext_scorer->split_labels(output); auto words = ext_scorer->split_labels(output);
// remove word insert // remove word insert
@ -169,12 +191,30 @@ std::vector<Output> ctc_beam_search_decoder(
// remove language model weight: // remove language model weight:
approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha; approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
} }
prefixes[i]->approx_ctc = approx_ctc; state->prefixes[i]->approx_ctc = approx_ctc;
} }
return get_beam_search_result(prefixes, beam_size); return get_beam_search_result(state->prefixes, beam_size);
} }
std::vector<Output> ctc_beam_search_decoder(
    const double *probs,
    int time_dim,
    int class_dim,
    const Alphabet &alphabet,
    size_t beam_size,
    double cutoff_prob,
    size_t cutoff_top_n,
    Scorer *ext_scorer) {
  // One-shot convenience wrapper over the streaming API:
  // decoder_init -> decoder_next (all timesteps at once) -> decoder_decode.
  DecoderState *state = decoder_init(alphabet, class_dim, ext_scorer);
  std::vector<Output> out;
  try {
    decoder_next(probs, alphabet, state, time_dim, class_dim,
                 cutoff_prob, cutoff_top_n, beam_size, ext_scorer);
    out = decoder_decode(state, alphabet, beam_size, ext_scorer);
  } catch (...) {
    // Fix: previously `state` leaked if decoding threw before the delete.
    delete state;
    throw;
  }
  delete state;
  return out;
}
std::vector<std::vector<Output>> std::vector<std::vector<Output>>
ctc_beam_search_decoder_batch( ctc_beam_search_decoder_batch(

View File

@ -7,12 +7,73 @@
#include "scorer.h" #include "scorer.h"
#include "output.h" #include "output.h"
#include "alphabet.h" #include "alphabet.h"
#include "decoderstate.h"
/* CTC Beam Search Decoder /* Initialize CTC beam search decoder
* Parameters: * Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities * alphabet: The alphabet.
* over alphabet of one time step. * class_dim: Alphabet length (plus 1 for space character).
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* A struct containing prefixes and state variables.
*/
DecoderState* decoder_init(const Alphabet &alphabet,
int class_dim,
Scorer *ext_scorer);
/* Send data to the decoder
* Parameters:
* probs: 2-D vector where each element is a vector of probabilities
* over alphabet of one time step.
* alphabet: The alphabet.
* state: The state structure previously obtained from decoder_init().
* time_dim: Number of timesteps.
* class_dim: Alphabet length (plus 1 for space character).
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
* beam_size: The width of beam search.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
*/
void decoder_next(const double *probs,
const Alphabet &alphabet,
DecoderState *state,
int time_dim,
int class_dim,
double cutoff_prob,
size_t cutoff_top_n,
size_t beam_size,
Scorer *ext_scorer);
/* Get transcription for the data you sent via decoder_next()
* Parameters:
* state: The state structure previously obtained from decoder_init().
* alphabet: The alphabet.
* beam_size: The width of beam search.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* A vector where each element is a pair of score and decoding result,
* in descending order.
*/
std::vector<Output> decoder_decode(DecoderState *state,
const Alphabet &alphabet,
size_t beam_size,
Scorer* ext_scorer);
/* CTC Beam Search Decoder
* Parameters:
* probs: 2-D vector where each element is a vector of probabilities
* over alphabet of one time step.
* time_dim: Number of timesteps.
* class_dim: Alphabet length (plus 1 for space character).
* alphabet: The alphabet. * alphabet: The alphabet.
* beam_size: The width of beam search. * beam_size: The width of beam search.
* cutoff_prob: Cutoff probability for pruning. * cutoff_prob: Cutoff probability for pruning.
@ -21,8 +82,8 @@
* n-gram language model scoring and word insertion term. * n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer. * Default null, decoding the input sample without scorer.
* Return: * Return:
* A vector that each element is a pair of score and decoding result, * A vector where each element is a pair of score and decoding result,
* in desending order. * in descending order.
*/ */
std::vector<Output> ctc_beam_search_decoder( std::vector<Output> ctc_beam_search_decoder(
@ -36,9 +97,8 @@ std::vector<Output> ctc_beam_search_decoder(
Scorer *ext_scorer); Scorer *ext_scorer);
/* CTC Beam Search Decoder for batch data /* CTC Beam Search Decoder for batch data
* Parameters: * Parameters:
* probs_seq: 3-D vector that each element is a 2-D vector that can be used * probs: 3-D vector where each element is a 2-D vector that can be used
* by ctc_beam_search_decoder(). * by ctc_beam_search_decoder().
* alphabet: The alphabet. * alphabet: The alphabet.
* beam_size: The width of beam search. * beam_size: The width of beam search.
@ -49,7 +109,7 @@ std::vector<Output> ctc_beam_search_decoder(
* n-gram language model scoring and word insertion term. * n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer. * Default null, decoding the input sample without scorer.
* Return: * Return:
* A 2-D vector that each element is a vector of beam search decoding * A 2-D vector where each element is a vector of beam search decoding
* result for one audio sample. * result for one audio sample.
*/ */
std::vector<std::vector<Output>> std::vector<std::vector<Output>>

View File

@ -0,0 +1,22 @@
#ifndef DECODERSTATE_H_
#define DECODERSTATE_H_

#include <vector>

// Forward declaration so this header is self-contained for declaration use.
// NOTE(review): translation units that destroy a DecoderState need the full
// PathTrie definition in scope — deleting through a pointer to an incomplete
// type is undefined behavior. Confirm the correct header to include here.
class PathTrie;

/* State of a streaming CTC beam-search decode: the live prefix beam, the
 * trie root that owns every prefix node, and the special alphabet ids. */
struct DecoderState {
  int space_id = -1;                // label id of the space character
  int blank_id = -1;                // CTC blank id (== alphabet size)
  std::vector<PathTrie*> prefixes;  // current beam; nodes owned via prefix_root
  PathTrie *prefix_root = nullptr;  // owning root of the prefix trie

  // DecoderState owns prefix_root through a raw pointer, so copying would
  // double-delete it: forbid copies (Rule of Five; moves are suppressed by
  // the user-declared destructor).
  DecoderState() = default;
  DecoderState(const DecoderState&) = delete;
  DecoderState& operator=(const DecoderState&) = delete;

  ~DecoderState() {
    delete prefix_root;  // `delete nullptr` is a no-op; no guard needed
    prefix_root = nullptr;
  }
};

#endif // DECODERSTATE_H_

View File

@ -75,13 +75,12 @@ using std::vector;
API. When audio_buffer is full, features are computed from it and pushed to API. When audio_buffer is full, features are computed from it and pushed to
mfcc_buffer. When mfcc_buffer is full, the timestep is copied to batch_buffer. mfcc_buffer. When mfcc_buffer is full, the timestep is copied to batch_buffer.
When batch_buffer is full, we do a single step through the acoustic model When batch_buffer is full, we do a single step through the acoustic model
and accumulate results in StreamingState::accumulated_logits. and accumulate results in the DecoderState structure.
When fininshStream() is called, we decode the accumulated logits and return When finishStream() is called, we decode the accumulated logits and return
the corresponding transcription. the corresponding transcription.
*/ */
struct StreamingState { struct StreamingState {
vector<float> accumulated_logits;
vector<float> audio_buffer; vector<float> audio_buffer;
vector<float> mfcc_buffer; vector<float> mfcc_buffer;
vector<float> batch_buffer; vector<float> batch_buffer;
@ -113,6 +112,7 @@ struct ModelState {
unsigned int ncontext; unsigned int ncontext;
Alphabet* alphabet; Alphabet* alphabet;
Scorer* scorer; Scorer* scorer;
DecoderState* decoder_state;
unsigned int beam_width; unsigned int beam_width;
unsigned int n_steps; unsigned int n_steps;
unsigned int n_context; unsigned int n_context;
@ -145,34 +145,26 @@ struct ModelState {
* @brief Perform decoding of the logits, using basic CTC decoder or * @brief Perform decoding of the logits, using basic CTC decoder or
* CTC decoder with KenLM enabled * CTC decoder with KenLM enabled
* *
* @param logits Flat matrix of logits, of size:
* n_frames * batch_size * num_classes
*
* @return String representing the decoded text. * @return String representing the decoded text.
*/ */
char* decode(const vector<float>& logits); char* decode();
/** /**
* @brief Perform decoding of the logits, using basic CTC decoder or * @brief Perform decoding of the logits, using basic CTC decoder or
* CTC decoder with KenLM enabled * CTC decoder with KenLM enabled
* *
* @param logits Flat matrix of logits, of size:
* n_frames * batch_size * num_classes
*
* @return Vector of Output structs directly from the CTC decoder for additional processing. * @return Vector of Output structs directly from the CTC decoder for additional processing.
*/ */
vector<Output> decode_raw(const vector<float>& logits); vector<Output> decode_raw();
/** /**
* @brief Return character-level metadata including letter timings. * @brief Return character-level metadata including letter timings.
* *
* @param logits Flat matrix of logits, of size:
* n_frames * batch_size * num_classes
* *
* @return Metadata struct containing MetadataItem structs for each character. * @return Metadata struct containing MetadataItem structs for each character.
* The user is responsible for freeing Metadata by calling DS_FreeMetadata(). * The user is responsible for freeing Metadata by calling DS_FreeMetadata().
*/ */
Metadata* decode_metadata(const vector<float>& logits); Metadata* decode_metadata();
/** /**
* @brief Do a single inference step in the acoustic model, with: * @brief Do a single inference step in the acoustic model, with:
@ -202,6 +194,7 @@ ModelState::ModelState()
, ncontext(0) , ncontext(0)
, alphabet(nullptr) , alphabet(nullptr)
, scorer(nullptr) , scorer(nullptr)
, decoder_state(nullptr)
, beam_width(0) , beam_width(0)
, n_steps(-1) , n_steps(-1)
, n_context(-1) , n_context(-1)
@ -232,6 +225,11 @@ ModelState::~ModelState()
delete scorer; delete scorer;
delete alphabet; delete alphabet;
if (decoder_state != nullptr) {
delete decoder_state;
decoder_state = nullptr;
}
} }
template<typename T> template<typename T>
@ -270,21 +268,21 @@ StreamingState::feedAudioContent(const short* buffer,
char* char*
StreamingState::intermediateDecode() StreamingState::intermediateDecode()
{ {
return model->decode(accumulated_logits); return model->decode();
} }
char* char*
StreamingState::finishStream() StreamingState::finishStream()
{ {
finalizeStream(); finalizeStream();
return model->decode(accumulated_logits); return model->decode();
} }
Metadata* Metadata*
StreamingState::finishStreamWithMetadata() StreamingState::finishStreamWithMetadata()
{ {
finalizeStream(); finalizeStream();
return model->decode_metadata(accumulated_logits); return model->decode_metadata();
} }
void void
@ -372,7 +370,26 @@ StreamingState::processMfccWindow(const vector<float>& buf)
void void
StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps) StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
{ {
model->infer(buf.data(), n_steps, accumulated_logits); vector<float> logits;
model->infer(buf.data(), n_steps, logits);
const int cutoff_top_n = 40;
const double cutoff_prob = 1.0;
const size_t num_classes = model->alphabet->GetSize() + 1; // +1 for blank
const int n_frames = logits.size() / (BATCH_SIZE * num_classes);
// Convert logits to double
vector<double> inputs(logits.begin(), logits.end());
decoder_next(inputs.data(),
*model->alphabet,
model->decoder_state,
n_frames,
num_classes,
cutoff_prob,
cutoff_top_n,
model->beam_width,
model->scorer);
} }
void void
@ -507,35 +524,24 @@ ModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_outpu
} }
char* char*
ModelState::decode(const vector<float>& logits) ModelState::decode()
{ {
vector<Output> out = ModelState::decode_raw(logits); vector<Output> out = ModelState::decode_raw();
return strdup(alphabet->LabelsToString(out[0].tokens).c_str()); return strdup(alphabet->LabelsToString(out[0].tokens).c_str());
} }
vector<Output> vector<Output>
ModelState::decode_raw(const vector<float>& logits) ModelState::decode_raw()
{ {
const int cutoff_top_n = 40; vector<Output> out = decoder_decode(decoder_state, *alphabet, beam_width, scorer);
const double cutoff_prob = 1.0;
const size_t num_classes = alphabet->GetSize() + 1; // +1 for blank
const int n_frames = logits.size() / (BATCH_SIZE * num_classes);
// Convert logits to double
vector<double> inputs(logits.begin(), logits.end());
// Vector of <probability, Output> pairs
vector<Output> out = ctc_beam_search_decoder(
inputs.data(), n_frames, num_classes, *alphabet, beam_width,
cutoff_prob, cutoff_top_n, scorer);
return out; return out;
} }
Metadata* Metadata*
ModelState::decode_metadata(const vector<float>& logits) ModelState::decode_metadata()
{ {
vector<Output> out = decode_raw(logits); vector<Output> out = decode_raw();
std::unique_ptr<Metadata> metadata(new Metadata()); std::unique_ptr<Metadata> metadata(new Metadata());
metadata->num_items = out[0].tokens.size(); metadata->num_items = out[0].tokens.size();
@ -830,8 +836,6 @@ DS_SetupStream(ModelState* aCtx,
aPreAllocFrames = 150; aPreAllocFrames = 150;
} }
ctx->accumulated_logits.reserve(aPreAllocFrames * BATCH_SIZE * num_classes);
ctx->audio_buffer.reserve(aCtx->audio_win_len); ctx->audio_buffer.reserve(aCtx->audio_win_len);
ctx->mfcc_buffer.reserve(aCtx->mfcc_feats_per_timestep); ctx->mfcc_buffer.reserve(aCtx->mfcc_feats_per_timestep);
ctx->mfcc_buffer.resize(aCtx->n_features*aCtx->n_context, 0.f); ctx->mfcc_buffer.resize(aCtx->n_features*aCtx->n_context, 0.f);
@ -839,6 +843,9 @@ DS_SetupStream(ModelState* aCtx,
ctx->model = aCtx; ctx->model = aCtx;
DecoderState *params = decoder_init(*aCtx->alphabet, num_classes, aCtx->scorer);
aCtx->decoder_state = params;
*retval = ctx.release(); *retval = ctx.release();
return DS_ERR_OK; return DS_ERR_OK;
} }