Merge pull request #2184 from mozilla/fix-intermediate-decode

Fix for DS_IntermediateDecode modifying live prefixes in DecoderState
This commit is contained in:
Reuben Morais 2019-06-20 10:37:26 -03:00 committed by GitHub
commit 080fc27c65
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 63 additions and 28 deletions

View File

@ -14,10 +14,11 @@
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
DecoderState* decoder_init(const Alphabet &alphabet,
int class_dim,
Scorer* ext_scorer) {
DecoderState*
decoder_init(const Alphabet &alphabet,
int class_dim,
Scorer* ext_scorer)
{
// dimension check
VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1,
"The shape of probs does not match with "
@ -47,16 +48,17 @@ DecoderState* decoder_init(const Alphabet &alphabet,
return state;
}
void decoder_next(const double *probs,
const Alphabet &alphabet,
DecoderState *state,
int time_dim,
int class_dim,
double cutoff_prob,
size_t cutoff_top_n,
size_t beam_size,
Scorer *ext_scorer) {
void
decoder_next(const double *probs,
const Alphabet &alphabet,
DecoderState *state,
int time_dim,
int class_dim,
double cutoff_prob,
size_t cutoff_top_n,
size_t beam_size,
Scorer *ext_scorer)
{
// prefix search over time
for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++state->time_step) {
auto *prob = &probs[rel_time_step*class_dim];
@ -155,39 +157,47 @@ void decoder_next(const double *probs,
} // end of loop over time
}
std::vector<Output> decoder_decode(DecoderState *state,
const Alphabet &alphabet,
size_t beam_size,
Scorer* ext_scorer) {
std::vector<Output>
decoder_decode(DecoderState *state,
const Alphabet &alphabet,
size_t beam_size,
Scorer* ext_scorer)
{
std::vector<PathTrie*> prefixes_copy = state->prefixes;
std::unordered_map<const PathTrie*, float> scores;
for (PathTrie* prefix : prefixes_copy) {
scores[prefix] = prefix->score;
}
// score the last word of each prefix that doesn't end with space
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) {
auto prefix = state->prefixes[i];
for (size_t i = 0; i < beam_size && i < prefixes_copy.size(); ++i) {
auto prefix = prefixes_copy[i];
if (!prefix->is_empty() && prefix->character != state->space_id) {
float score = 0.0;
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
score += ext_scorer->beta;
prefix->score += score;
scores[prefix] += score;
}
}
}
size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
std::sort(state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);
using namespace std::placeholders;
size_t num_prefixes = std::min(prefixes_copy.size(), beam_size);
std::sort(prefixes_copy.begin(), prefixes_copy.begin() + num_prefixes, std::bind(prefix_compare_external, _1, _2, scores));
//TODO: expose this as an API parameter
const int top_paths = 1;
// compute the approximate CTC score to use as the returned score, without
// affecting the order of the decoding results. To be deleted once the decoder is stable.
for (size_t i = 0; i < top_paths && i < state->prefixes.size(); ++i) {
double approx_ctc = state->prefixes[i]->score;
for (size_t i = 0; i < top_paths && i < prefixes_copy.size(); ++i) {
double approx_ctc = scores[prefixes_copy[i]];
if (ext_scorer != nullptr) {
std::vector<int> output;
std::vector<int> timesteps;
state->prefixes[i]->get_path_vec(output, timesteps);
prefixes_copy[i]->get_path_vec(output, timesteps);
auto prefix_length = output.size();
auto words = ext_scorer->split_labels(output);
// remove word insert
@ -195,10 +205,10 @@ std::vector<Output> decoder_decode(DecoderState *state,
// remove language model weight:
approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
}
state->prefixes[i]->approx_ctc = approx_ctc;
prefixes_copy[i]->approx_ctc = approx_ctc;
}
return get_beam_search_result(state->prefixes, top_paths);
return get_beam_search_result(prefixes_copy, top_paths);
}
std::vector<Output> ctc_beam_search_decoder(

View File

@ -112,6 +112,18 @@ bool prefix_compare(const PathTrie *x, const PathTrie *y) {
}
}
// Orders two prefixes using externally supplied scores instead of the `score`
// field stored on the PathTrie nodes themselves.
//
// Returns true when `x` should be placed before `y`:
//   - the prefix with the higher entry in `scores` comes first;
//   - when both entries compare equal, the prefix with the smaller
//     `character` comes first;
//   - when the characters are equal as well, false is returned (neither
//     prefix precedes the other).
//
// Both `x` and `y` must be keys of `scores`: the lookups use
// std::unordered_map::at, which throws std::out_of_range for a missing key.
bool prefix_compare_external(const PathTrie *x, const PathTrie *y, const std::unordered_map<const PathTrie*, float>& scores) {
  // Look each score up exactly once; this function is called many times as a
  // sort comparator, so repeated hash lookups per call add up.
  const float x_score = scores.at(x);
  const float y_score = scores.at(y);

  if (x_score != y_score) {
    return x_score > y_score;
  }
  // Equal scores: break the tie on the character. Equal characters yield
  // false, which keeps the comparison a valid strict ordering.
  return x->character < y->character;
}
void add_word_to_fst(const std::vector<int> &word,
fst::StdVectorFst *dictionary) {
if (dictionary->NumStates() == 0) {

View File

@ -67,6 +67,8 @@ std::vector<Output> get_beam_search_result(
// Functor for prefix comparison
bool prefix_compare(const PathTrie *x, const PathTrie *y);
// Prefix comparison driven by the scores stored for each prefix in `scores`
// rather than by the prefixes' own `score` fields: the higher score comes
// first, and equal scores fall back to the smaller `character` first.
// Both prefixes must be present as keys in `scores`.
bool prefix_compare_external(const PathTrie *x, const PathTrie *y, const std::unordered_map<const PathTrie*, float>& scores);
/* Get length of utf8 encoding string
* See: http://stackoverflow.com/a/4063229
*/

View File

@ -13,3 +13,5 @@ check_tensorflow_version
run_all_inference_tests
run_multi_inference_tests
run_cpp_only_inference_tests

View File

@ -460,6 +460,15 @@ run_multi_inference_tests()
assert_correct_multi_ldc93s1 "${multi_phrase_pbmodel_withlm}" "$status"
}
# Runs the `deepspeech` client on LDC93S1.wav with the language model and trie,
# passing `--stream 1280`, keeps only the last line of its standard output and
# checks that line (together with the exit status) via assert_correct_ldc93s1_lm.
# Standard error is redirected to ${TASKCLUSTER_TMP_DIR}/stderr.
run_cpp_only_inference_tests()
{
  # Do not abort on a failing command; the failure is handed to the assertion
  # helper through $status instead.
  set +e
  # pipefail is enabled only inside the command substitution's subshell, so the
  # setting does not leak into the rest of the script. Without it, $? would
  # always be the exit status of `tail` and a failing deepspeech run would go
  # unnoticed.
  # NOTE(review): `set -o pipefail` assumes this script is run by bash - confirm.
  phrase_pbmodel_withlm_intermediate_decode=$(set -o pipefail; deepspeech --model "${TASKCLUSTER_TMP_DIR}/${model_name_mmap}" --alphabet "${TASKCLUSTER_TMP_DIR}/alphabet.txt" --lm "${TASKCLUSTER_TMP_DIR}/lm.binary" --trie "${TASKCLUSTER_TMP_DIR}/trie" --audio "${TASKCLUSTER_TMP_DIR}/LDC93S1.wav" --stream 1280 2>"${TASKCLUSTER_TMP_DIR}/stderr" | tail -n 1)
  status=$?
  set -e
  assert_correct_ldc93s1_lm "${phrase_pbmodel_withlm_intermediate_decode}" "$status"
}
android_run_tests()
{
cd ${DS_DSDIR}/native_client/java/