Clean up formatting in ctc_beam_saerch_decoder.cpp
This commit is contained in:
parent
e136b5299a
commit
d5d55958da
@ -14,10 +14,11 @@
|
|||||||
|
|
||||||
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
||||||
|
|
||||||
DecoderState* decoder_init(const Alphabet &alphabet,
|
DecoderState*
|
||||||
int class_dim,
|
decoder_init(const Alphabet &alphabet,
|
||||||
Scorer* ext_scorer) {
|
int class_dim,
|
||||||
|
Scorer* ext_scorer)
|
||||||
|
{
|
||||||
// dimension check
|
// dimension check
|
||||||
VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1,
|
VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1,
|
||||||
"The shape of probs does not match with "
|
"The shape of probs does not match with "
|
||||||
@ -47,16 +48,17 @@ DecoderState* decoder_init(const Alphabet &alphabet,
|
|||||||
return state;
|
return state;
|
||||||
}
|
}
|
||||||
|
|
||||||
void decoder_next(const double *probs,
|
void
|
||||||
const Alphabet &alphabet,
|
decoder_next(const double *probs,
|
||||||
DecoderState *state,
|
const Alphabet &alphabet,
|
||||||
int time_dim,
|
DecoderState *state,
|
||||||
int class_dim,
|
int time_dim,
|
||||||
double cutoff_prob,
|
int class_dim,
|
||||||
size_t cutoff_top_n,
|
double cutoff_prob,
|
||||||
size_t beam_size,
|
size_t cutoff_top_n,
|
||||||
Scorer *ext_scorer) {
|
size_t beam_size,
|
||||||
|
Scorer *ext_scorer)
|
||||||
|
{
|
||||||
// prefix search over time
|
// prefix search over time
|
||||||
for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++state->time_step) {
|
for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++state->time_step) {
|
||||||
auto *prob = &probs[rel_time_step*class_dim];
|
auto *prob = &probs[rel_time_step*class_dim];
|
||||||
@ -155,11 +157,12 @@ void decoder_next(const double *probs,
|
|||||||
} // end of loop over time
|
} // end of loop over time
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Output> decoder_decode(DecoderState *state,
|
std::vector<Output>
|
||||||
const Alphabet &alphabet,
|
decoder_decode(DecoderState *state,
|
||||||
size_t beam_size,
|
const Alphabet &alphabet,
|
||||||
Scorer* ext_scorer) {
|
size_t beam_size,
|
||||||
|
Scorer* ext_scorer)
|
||||||
|
{
|
||||||
// score the last word of each prefix that doesn't end with space
|
// score the last word of each prefix that doesn't end with space
|
||||||
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
|
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
|
||||||
for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) {
|
for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user