Merge pull request #2121 from dabinat/streaming-decoder
CTC streaming decoder
Commit 69538f2f62
@@ -14,49 +14,61 @@

 using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;

-std::vector<Output> ctc_beam_search_decoder(
-    const double *probs,
-    int time_dim,
-    int class_dim,
-    const Alphabet &alphabet,
-    size_t beam_size,
-    double cutoff_prob,
-    size_t cutoff_top_n,
-    Scorer *ext_scorer) {
+DecoderState* decoder_init(const Alphabet &alphabet,
+                           int class_dim,
+                           Scorer* ext_scorer) {
   // dimension check
   VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1,
                  "The shape of probs does not match with "
                  "the shape of the vocabulary");

   // assign special ids
-  int space_id = alphabet.GetSpaceLabel();
-  int blank_id = alphabet.GetSize();
+  DecoderState *state = new DecoderState;
+  state->space_id = alphabet.GetSpaceLabel();
+  state->blank_id = alphabet.GetSize();

   // init prefixes' root
-  PathTrie root;
-  root.score = root.log_prob_b_prev = 0.0;
-  std::vector<PathTrie *> prefixes;
-  prefixes.push_back(&root);
+  PathTrie *root = new PathTrie;
+  root->score = root->log_prob_b_prev = 0.0;
+  state->prefix_root = root;
+  state->prefixes.push_back(root);

   if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
     auto dict_ptr = ext_scorer->dictionary->Copy(true);
-    root.set_dictionary(dict_ptr);
+    root->set_dictionary(dict_ptr);
     auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
-    root.set_matcher(matcher);
+    root->set_matcher(matcher);
   }

+  return state;
+}
+
-  // prefix search over time
+void decoder_next(const double *probs,
+                  const Alphabet &alphabet,
+                  DecoderState *state,
+                  int time_dim,
+                  int class_dim,
+                  double cutoff_prob,
+                  size_t cutoff_top_n,
+                  size_t beam_size,
+                  Scorer *ext_scorer) {
+  // prefix search over time
   for (size_t time_step = 0; time_step < time_dim; ++time_step) {
     auto *prob = &probs[time_step*class_dim];

     float min_cutoff = -NUM_FLT_INF;
     bool full_beam = false;
     if (ext_scorer != nullptr) {
-      size_t num_prefixes = std::min(prefixes.size(), beam_size);
+      size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
       std::sort(
-          prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
-      min_cutoff = prefixes[num_prefixes - 1]->score +
-                   std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta);
+          state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);
+      min_cutoff = state->prefixes[num_prefixes - 1]->score +
+                   std::log(prob[state->blank_id]) - std::max(0.0, ext_scorer->beta);
       full_beam = (num_prefixes == beam_size);
     }

@@ -67,22 +79,25 @@ std::vector<Output> ctc_beam_search_decoder(
       auto c = log_prob_idx[index].first;
       auto log_prob_c = log_prob_idx[index].second;

-      for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
-        auto prefix = prefixes[i];
+      for (size_t i = 0; i < state->prefixes.size() && i < beam_size; ++i) {
+        auto prefix = state->prefixes[i];
         if (full_beam && log_prob_c + prefix->score < min_cutoff) {
           break;
         }

         // blank
-        if (c == blank_id) {
+        if (c == state->blank_id) {
           prefix->log_prob_b_cur =
               log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
           continue;
         }

         // repeated character
         if (c == prefix->character) {
           prefix->log_prob_nb_cur = log_sum_exp(
               prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
         }

         // get new prefix
         auto prefix_new = prefix->get_path_trie(c, time_step, log_prob_c);

@@ -98,7 +113,7 @@ std::vector<Output> ctc_beam_search_decoder(

         // language model scoring
         if (ext_scorer != nullptr &&
-            (c == space_id || ext_scorer->is_character_based())) {
+            (c == state->space_id || ext_scorer->is_character_based())) {
           PathTrie *prefix_to_score = nullptr;
           // skip scoring the space
           if (ext_scorer->is_character_based()) {
@@ -114,34 +129,41 @@ std::vector<Output> ctc_beam_search_decoder(
             log_p += score;
             log_p += ext_scorer->beta;
           }

           prefix_new->log_prob_nb_cur =
               log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
         }
       } // end of loop over prefix
     } // end of loop over vocabulary

-    prefixes.clear();
     // update log probs
-    root.iterate_to_vec(prefixes);
+    state->prefixes.clear();
+    state->prefix_root->iterate_to_vec(state->prefixes);

     // only preserve top beam_size prefixes
-    if (prefixes.size() >= beam_size) {
-      std::nth_element(prefixes.begin(),
-                       prefixes.begin() + beam_size,
-                       prefixes.end(),
+    if (state->prefixes.size() >= beam_size) {
+      std::nth_element(state->prefixes.begin(),
+                       state->prefixes.begin() + beam_size,
+                       state->prefixes.end(),
                        prefix_compare);
-      for (size_t i = beam_size; i < prefixes.size(); ++i) {
-        prefixes[i]->remove();
+      for (size_t i = beam_size; i < state->prefixes.size(); ++i) {
+        state->prefixes[i]->remove();
       }
     }

   } // end of loop over time
+}
+
+std::vector<Output> decoder_decode(DecoderState *state,
+                                   const Alphabet &alphabet,
+                                   size_t beam_size,
+                                   Scorer* ext_scorer) {

   // score the last word of each prefix that doesn't end with space
   if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
-    for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
-      auto prefix = prefixes[i];
-      if (!prefix->is_empty() && prefix->character != space_id) {
+    for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) {
+      auto prefix = state->prefixes[i];
+      if (!prefix->is_empty() && prefix->character != state->space_id) {
         float score = 0.0;
         std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
         score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
@@ -151,17 +173,17 @@ std::vector<Output> ctc_beam_search_decoder(
    }
  }

-  size_t num_prefixes = std::min(prefixes.size(), beam_size);
-  std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
+  size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
+  std::sort(state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);

   // compute aproximate ctc score as the return score, without affecting the
   // return order of decoding result. To delete when decoder gets stable.
-  for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
-    double approx_ctc = prefixes[i]->score;
+  for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) {
+    double approx_ctc = state->prefixes[i]->score;
     if (ext_scorer != nullptr) {
       std::vector<int> output;
       std::vector<int> timesteps;
-      prefixes[i]->get_path_vec(output, timesteps);
+      state->prefixes[i]->get_path_vec(output, timesteps);
       auto prefix_length = output.size();
       auto words = ext_scorer->split_labels(output);
       // remove word insert
@@ -169,12 +191,30 @@ std::vector<Output> ctc_beam_search_decoder(
       // remove language model weight:
       approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
     }
-    prefixes[i]->approx_ctc = approx_ctc;
+    state->prefixes[i]->approx_ctc = approx_ctc;
   }

-  return get_beam_search_result(prefixes, beam_size);
+  return get_beam_search_result(state->prefixes, beam_size);
 }

+std::vector<Output> ctc_beam_search_decoder(
+    const double *probs,
+    int time_dim,
+    int class_dim,
+    const Alphabet &alphabet,
+    size_t beam_size,
+    double cutoff_prob,
+    size_t cutoff_top_n,
+    Scorer *ext_scorer) {
+
+  DecoderState *state = decoder_init(alphabet, class_dim, ext_scorer);
+  decoder_next(probs, alphabet, state, time_dim, class_dim, cutoff_prob, cutoff_top_n, beam_size, ext_scorer);
+  std::vector<Output> out = decoder_decode(state, alphabet, beam_size, ext_scorer);
+
+  delete state;
+
+  return out;
+}
+
 std::vector<std::vector<Output>>
 ctc_beam_search_decoder_batch(
@@ -7,12 +7,73 @@
 #include "scorer.h"
 #include "output.h"
 #include "alphabet.h"
+#include "decoderstate.h"

-/* CTC Beam Search Decoder
+/* Initialize CTC beam search decoder

  * Parameters:
- *     probs_seq: 2-D vector that each element is a vector of probabilities
- *               over alphabet of one time step.
+ *     alphabet: The alphabet.
+ *     class_dim: Alphabet length (plus 1 for space character).
+ *     ext_scorer: External scorer to evaluate a prefix, which consists of
+ *                 n-gram language model scoring and word insertion term.
+ *                 Default null, decoding the input sample without scorer.
+ * Return:
+ *     A struct containing prefixes and state variables.
+*/
+DecoderState* decoder_init(const Alphabet &alphabet,
+                           int class_dim,
+                           Scorer *ext_scorer);
+
+/* Send data to the decoder
+
+ * Parameters:
+ *     probs: 2-D vector where each element is a vector of probabilities
+ *            over alphabet of one time step.
+ *     alphabet: The alphabet.
+ *     state: The state structure previously obtained from decoder_init().
+ *     time_dim: Number of timesteps.
+ *     class_dim: Alphabet length (plus 1 for space character).
+ *     cutoff_prob: Cutoff probability for pruning.
+ *     cutoff_top_n: Cutoff number for pruning.
+ *     beam_size: The width of beam search.
+ *     ext_scorer: External scorer to evaluate a prefix, which consists of
+ *                 n-gram language model scoring and word insertion term.
+ *                 Default null, decoding the input sample without scorer.
+*/
+void decoder_next(const double *probs,
+                  const Alphabet &alphabet,
+                  DecoderState *state,
+                  int time_dim,
+                  int class_dim,
+                  double cutoff_prob,
+                  size_t cutoff_top_n,
+                  size_t beam_size,
+                  Scorer *ext_scorer);
+
+/* Get transcription for the data you sent via decoder_next()
+
+ * Parameters:
+ *     state: The state structure previously obtained from decoder_init().
+ *     alphabet: The alphabet.
+ *     beam_size: The width of beam search.
+ *     ext_scorer: External scorer to evaluate a prefix, which consists of
+ *                 n-gram language model scoring and word insertion term.
+ *                 Default null, decoding the input sample without scorer.
+ * Return:
+ *     A vector where each element is a pair of score and decoding result,
+ *     in descending order.
+*/
+std::vector<Output> decoder_decode(DecoderState *state,
+                                   const Alphabet &alphabet,
+                                   size_t beam_size,
+                                   Scorer* ext_scorer);
+
+/* CTC Beam Search Decoder
+ * Parameters:
+ *     probs: 2-D vector where each element is a vector of probabilities
+ *            over alphabet of one time step.
+ *     time_dim: Number of timesteps.
+ *     class_dim: Alphabet length (plus 1 for space character).
  *     alphabet: The alphabet.
  *     beam_size: The width of beam search.
  *     cutoff_prob: Cutoff probability for pruning.
@@ -21,8 +82,8 @@
  *                 n-gram language model scoring and word insertion term.
  *                 Default null, decoding the input sample without scorer.
  * Return:
- *     A vector that each element is a pair of score and decoding result,
- *     in desending order.
+ *     A vector where each element is a pair of score and decoding result,
+ *     in descending order.
 */

 std::vector<Output> ctc_beam_search_decoder(
@@ -36,9 +97,8 @@ std::vector<Output> ctc_beam_search_decoder(
     Scorer *ext_scorer);

 /* CTC Beam Search Decoder for batch data

  * Parameters:
- *     probs_seq: 3-D vector that each element is a 2-D vector that can be used
+ *     probs: 3-D vector where each element is a 2-D vector that can be used
  *               by ctc_beam_search_decoder().
  *     alphabet: The alphabet.
  *     beam_size: The width of beam search.
@@ -49,7 +109,7 @@ std::vector<Output> ctc_beam_search_decoder(
  *                 n-gram language model scoring and word insertion term.
  *                 Default null, decoding the input sample without scorer.
  * Return:
- *     A 2-D vector that each element is a vector of beam search decoding
+ *     A 2-D vector where each element is a vector of beam search decoding
  *     result for one audio sample.
 */
 std::vector<std::vector<Output>>
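The comments above describe probs as a 2-D vector of per-timestep distributions, but decoder_next() actually consumes a flat double array indexed as probs[time_step * class_dim + c] (see the .cpp hunk above). A small self-contained sketch of producing that layout; the flatten_probs helper and its inputs are invented for illustration.

// Illustrative only: flatten per-timestep probability rows into the flat layout
// decoder_next() expects, i.e. entry (t, c) lives at flat[t * class_dim + c].
#include <vector>

std::vector<double> flatten_probs(const std::vector<std::vector<double>> &per_step_probs,
                                  int class_dim) {
  std::vector<double> flat;
  flat.reserve(per_step_probs.size() * class_dim);
  for (const std::vector<double> &step : per_step_probs) {
    // each row holds class_dim probabilities for one time step
    flat.insert(flat.end(), step.begin(), step.end());
  }
  return flat;  // pass flat.data() as probs and per_step_probs.size() as time_dim
}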
native_client/ctcdecode/decoderstate.h (new file, 22 lines)
@@ -0,0 +1,22 @@
+#ifndef DECODERSTATE_H_
+#define DECODERSTATE_H_
+
+#include <vector>
+
+/* Struct for the state of the decoder, containing the prefixes and initial root prefix plus state variables. */
+
+struct DecoderState {
+  int space_id;
+  int blank_id;
+  std::vector<PathTrie*> prefixes;
+  PathTrie *prefix_root;
+
+  ~DecoderState() {
+    if (prefix_root != nullptr) {
+      delete prefix_root;
+    }
+    prefix_root = nullptr;
+  }
+};
+
+#endif // DECODERSTATE_H_
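A small usage note on the new header: it refers to PathTrie (and deletes prefix_root in the destructor) but only includes <vector>, so a complete definition of PathTrie has to be visible before decoderstate.h is included. A minimal sketch, assuming the module's existing path_trie.h header provides that definition:

// Illustrative include order: PathTrie must be a complete type before
// decoderstate.h is included, because ~DecoderState() does `delete prefix_root;`.
#include "path_trie.h"      // assumed module header defining PathTrie
#include "decoderstate.h"   // the new struct added by this PR

void destroy_state(DecoderState *state) {
  delete state;  // runs ~DecoderState(), which deletes prefix_root if set
}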
@@ -75,13 +75,12 @@ using std::vector;
    API. When audio_buffer is full, features are computed from it and pushed to
    mfcc_buffer. When mfcc_buffer is full, the timestep is copied to batch_buffer.
    When batch_buffer is full, we do a single step through the acoustic model
-   and accumulate results in StreamingState::accumulated_logits.
+   and accumulate results in the DecoderState structure.

-   When fininshStream() is called, we decode the accumulated logits and return
+   When finishStream() is called, we decode the accumulated logits and return
    the corresponding transcription.
 */
 struct StreamingState {
-  vector<float> accumulated_logits;
   vector<float> audio_buffer;
   vector<float> mfcc_buffer;
   vector<float> batch_buffer;
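The comment block above lays out a push-until-full pipeline (audio_buffer to mfcc_buffer to batch_buffer, then one acoustic-model step). The snippet below is only a generic sketch of that pattern; none of these helper names exist in the PR, and the real code lives in feedAudioContent(), processMfccWindow(), and processBatch() further down.

// Generic "push until full, then flush" sketch of the buffering described above.
#include <cstddef>
#include <functional>
#include <vector>

void push_and_flush(std::vector<float> &buffer, std::size_t capacity, float value,
                    const std::function<void(std::vector<float>&)> &on_full) {
  buffer.push_back(value);
  if (buffer.size() == capacity) {
    on_full(buffer);   // e.g. compute features, copy a timestep, or run one model step
    buffer.clear();
  }
}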
@@ -113,6 +112,7 @@ struct ModelState {
   unsigned int ncontext;
   Alphabet* alphabet;
   Scorer* scorer;
+  DecoderState* decoder_state;
   unsigned int beam_width;
   unsigned int n_steps;
   unsigned int n_context;
@@ -145,34 +145,26 @@ struct ModelState {
   * @brief Perform decoding of the logits, using basic CTC decoder or
   *        CTC decoder with KenLM enabled
   *
-  * @param logits Flat matrix of logits, of size:
-  *               n_frames * batch_size * num_classes
-  *
   * @return String representing the decoded text.
   */
-  char* decode(const vector<float>& logits);
+  char* decode();

  /**
   * @brief Perform decoding of the logits, using basic CTC decoder or
   *        CTC decoder with KenLM enabled
   *
-  * @param logits Flat matrix of logits, of size:
-  *               n_frames * batch_size * num_classes
-  *
   * @return Vector of Output structs directly from the CTC decoder for additional processing.
   */
-  vector<Output> decode_raw(const vector<float>& logits);
+  vector<Output> decode_raw();

  /**
   * @brief Return character-level metadata including letter timings.
   *
-  * @param logits Flat matrix of logits, of size:
-  *               n_frames * batch_size * num_classes
   *
   * @return Metadata struct containing MetadataItem structs for each character.
   *         The user is responsible for freeing Metadata by calling DS_FreeMetadata().
   */
-  Metadata* decode_metadata(const vector<float>& logits);
+  Metadata* decode_metadata();

  /**
   * @brief Do a single inference step in the acoustic model, with:
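The decode_metadata() comment above leaves the caller responsible for releasing the Metadata with DS_FreeMetadata(). A hedged caller-side sketch, assuming the surrounding deepspeech.cc/deepspeech.h declarations; the consume_metadata helper and the `stream` pointer are invented for illustration, and the struct is produced via finishStreamWithMetadata() shown further down.

// Illustrative only: ownership of the Metadata produced by the calls documented above.
void consume_metadata(StreamingState* stream) {   // `stream` comes from DS_SetupStream()
  Metadata *metadata = stream->finishStreamWithMetadata();
  // ... read the metadata->num_items character/timing entries here ...
  DS_FreeMetadata(metadata);                      // caller frees, per the doc comment
}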
@@ -202,6 +194,7 @@ ModelState::ModelState()
   , ncontext(0)
   , alphabet(nullptr)
   , scorer(nullptr)
+  , decoder_state(nullptr)
   , beam_width(0)
   , n_steps(-1)
   , n_context(-1)
@@ -232,6 +225,11 @@ ModelState::~ModelState()

   delete scorer;
   delete alphabet;
+
+  if (decoder_state != nullptr) {
+    delete decoder_state;
+    decoder_state = nullptr;
+  }
 }

 template<typename T>
@@ -270,21 +268,21 @@ StreamingState::feedAudioContent(const short* buffer,
 char*
 StreamingState::intermediateDecode()
 {
-  return model->decode(accumulated_logits);
+  return model->decode();
 }

 char*
 StreamingState::finishStream()
 {
   finalizeStream();
-  return model->decode(accumulated_logits);
+  return model->decode();
 }

 Metadata*
 StreamingState::finishStreamWithMetadata()
 {
   finalizeStream();
-  return model->decode_metadata(accumulated_logits);
+  return model->decode_metadata();
 }

 void
@@ -372,7 +370,26 @@ StreamingState::processMfccWindow(const vector<float>& buf)
 void
 StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
 {
-  model->infer(buf.data(), n_steps, accumulated_logits);
+  vector<float> logits;
+  model->infer(buf.data(), n_steps, logits);
+
+  const int cutoff_top_n = 40;
+  const double cutoff_prob = 1.0;
+  const size_t num_classes = model->alphabet->GetSize() + 1; // +1 for blank
+  const int n_frames = logits.size() / (BATCH_SIZE * num_classes);
+
+  // Convert logits to double
+  vector<double> inputs(logits.begin(), logits.end());
+
+  decoder_next(inputs.data(),
+               *model->alphabet,
+               model->decoder_state,
+               n_frames,
+               num_classes,
+               cutoff_prob,
+               cutoff_top_n,
+               model->beam_width,
+               model->scorer);
 }

 void
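For concreteness, the frame-count arithmetic in processBatch() above works out as follows. The sizes here are made up for illustration (a 28-character alphabet and a batch size of 1 are assumptions, not values taken from the PR).

// Worked example of n_frames = logits.size() / (BATCH_SIZE * num_classes).
#include <cassert>
#include <cstddef>

int main() {
  const std::size_t BATCH_SIZE = 1;                 // assumed batch size
  const std::size_t num_classes = 28 + 1;           // alphabet size + 1 for the CTC blank
  const std::size_t logits_size = 16 * BATCH_SIZE * num_classes;   // 464 float logits
  const std::size_t n_frames = logits_size / (BATCH_SIZE * num_classes);
  assert(n_frames == 16);                           // 16 timesteps fed to decoder_next()
  return 0;
}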
@@ -507,35 +524,24 @@ ModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_outpu
 }

 char*
-ModelState::decode(const vector<float>& logits)
+ModelState::decode()
 {
-  vector<Output> out = ModelState::decode_raw(logits);
+  vector<Output> out = ModelState::decode_raw();
   return strdup(alphabet->LabelsToString(out[0].tokens).c_str());
 }

 vector<Output>
-ModelState::decode_raw(const vector<float>& logits)
+ModelState::decode_raw()
 {
-  const int cutoff_top_n = 40;
-  const double cutoff_prob = 1.0;
-  const size_t num_classes = alphabet->GetSize() + 1; // +1 for blank
-  const int n_frames = logits.size() / (BATCH_SIZE * num_classes);
-
-  // Convert logits to double
-  vector<double> inputs(logits.begin(), logits.end());
-
-  // Vector of <probability, Output> pairs
-  vector<Output> out = ctc_beam_search_decoder(
-      inputs.data(), n_frames, num_classes, *alphabet, beam_width,
-      cutoff_prob, cutoff_top_n, scorer);
+  vector<Output> out = decoder_decode(decoder_state, *alphabet, beam_width, scorer);

   return out;
 }

 Metadata*
-ModelState::decode_metadata(const vector<float>& logits)
+ModelState::decode_metadata()
 {
-  vector<Output> out = decode_raw(logits);
+  vector<Output> out = decode_raw();

   std::unique_ptr<Metadata> metadata(new Metadata());
   metadata->num_items = out[0].tokens.size();
@@ -830,8 +836,6 @@ DS_SetupStream(ModelState* aCtx,
     aPreAllocFrames = 150;
   }

-  ctx->accumulated_logits.reserve(aPreAllocFrames * BATCH_SIZE * num_classes);
-
   ctx->audio_buffer.reserve(aCtx->audio_win_len);
   ctx->mfcc_buffer.reserve(aCtx->mfcc_feats_per_timestep);
   ctx->mfcc_buffer.resize(aCtx->n_features*aCtx->n_context, 0.f);
@@ -839,6 +843,9 @@ DS_SetupStream(ModelState* aCtx,

   ctx->model = aCtx;

+  DecoderState *params = decoder_init(*aCtx->alphabet, num_classes, aCtx->scorer);
+  aCtx->decoder_state = params;
+
   *retval = ctx.release();
   return DS_ERR_OK;
 }
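As a quick map of the deepspeech.cc changes, the comment block below summarizes where each call into the new decoder API now happens. It restates the hunks above and adds no new behavior.

// Where the new decoder API is driven inside deepspeech.cc after this PR:
//
//   DS_SetupStream()                   -> decoder_init(*alphabet, num_classes, scorer)
//   StreamingState::processBatch()     -> decoder_next(...) after each acoustic-model step
//   ModelState::decode_raw()/decode()  -> decoder_decode(decoder_state, ...)
//   ModelState::~ModelState()          -> delete decoder_state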