Clean up formatting in ctc_beam_saerch_decoder.cpp

This commit is contained in:
Reuben Morais 2019-06-18 12:08:00 -03:00
parent e136b5299a
commit d5d55958da

View File

@ -14,10 +14,11 @@
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>; using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
DecoderState* decoder_init(const Alphabet &alphabet, DecoderState*
int class_dim, decoder_init(const Alphabet &alphabet,
Scorer* ext_scorer) { int class_dim,
Scorer* ext_scorer)
{
// dimension check // dimension check
VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1, VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1,
"The shape of probs does not match with " "The shape of probs does not match with "
@ -47,16 +48,17 @@ DecoderState* decoder_init(const Alphabet &alphabet,
return state; return state;
} }
void decoder_next(const double *probs, void
const Alphabet &alphabet, decoder_next(const double *probs,
DecoderState *state, const Alphabet &alphabet,
int time_dim, DecoderState *state,
int class_dim, int time_dim,
double cutoff_prob, int class_dim,
size_t cutoff_top_n, double cutoff_prob,
size_t beam_size, size_t cutoff_top_n,
Scorer *ext_scorer) { size_t beam_size,
Scorer *ext_scorer)
{
// prefix search over time // prefix search over time
for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++state->time_step) { for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++state->time_step) {
auto *prob = &probs[rel_time_step*class_dim]; auto *prob = &probs[rel_time_step*class_dim];
@ -155,11 +157,12 @@ void decoder_next(const double *probs,
} // end of loop over time } // end of loop over time
} }
std::vector<Output> decoder_decode(DecoderState *state, std::vector<Output>
const Alphabet &alphabet, decoder_decode(DecoderState *state,
size_t beam_size, const Alphabet &alphabet,
Scorer* ext_scorer) { size_t beam_size,
Scorer* ext_scorer)
{
// score the last word of each prefix that doesn't end with space // score the last word of each prefix that doesn't end with space
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) { for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) {