Merge pull request #2184 from mozilla/fix-intermediate-decode
Fix for DS_IntermediateDecode modifying live prefixes in DecoderState
This commit is contained in:
commit
080fc27c65
@ -14,10 +14,11 @@
|
||||
|
||||
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
||||
|
||||
DecoderState* decoder_init(const Alphabet &alphabet,
|
||||
int class_dim,
|
||||
Scorer* ext_scorer) {
|
||||
|
||||
DecoderState*
|
||||
decoder_init(const Alphabet &alphabet,
|
||||
int class_dim,
|
||||
Scorer* ext_scorer)
|
||||
{
|
||||
// dimension check
|
||||
VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1,
|
||||
"The shape of probs does not match with "
|
||||
@ -47,16 +48,17 @@ DecoderState* decoder_init(const Alphabet &alphabet,
|
||||
return state;
|
||||
}
|
||||
|
||||
void decoder_next(const double *probs,
|
||||
const Alphabet &alphabet,
|
||||
DecoderState *state,
|
||||
int time_dim,
|
||||
int class_dim,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
size_t beam_size,
|
||||
Scorer *ext_scorer) {
|
||||
|
||||
void
|
||||
decoder_next(const double *probs,
|
||||
const Alphabet &alphabet,
|
||||
DecoderState *state,
|
||||
int time_dim,
|
||||
int class_dim,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
size_t beam_size,
|
||||
Scorer *ext_scorer)
|
||||
{
|
||||
// prefix search over time
|
||||
for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++state->time_step) {
|
||||
auto *prob = &probs[rel_time_step*class_dim];
|
||||
@ -155,39 +157,47 @@ void decoder_next(const double *probs,
|
||||
} // end of loop over time
|
||||
}
|
||||
|
||||
std::vector<Output> decoder_decode(DecoderState *state,
|
||||
const Alphabet &alphabet,
|
||||
size_t beam_size,
|
||||
Scorer* ext_scorer) {
|
||||
std::vector<Output>
|
||||
decoder_decode(DecoderState *state,
|
||||
const Alphabet &alphabet,
|
||||
size_t beam_size,
|
||||
Scorer* ext_scorer)
|
||||
{
|
||||
std::vector<PathTrie*> prefixes_copy = state->prefixes;
|
||||
std::unordered_map<const PathTrie*, float> scores;
|
||||
for (PathTrie* prefix : prefixes_copy) {
|
||||
scores[prefix] = prefix->score;
|
||||
}
|
||||
|
||||
// score the last word of each prefix that doesn't end with space
|
||||
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
|
||||
for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) {
|
||||
auto prefix = state->prefixes[i];
|
||||
for (size_t i = 0; i < beam_size && i < prefixes_copy.size(); ++i) {
|
||||
auto prefix = prefixes_copy[i];
|
||||
if (!prefix->is_empty() && prefix->character != state->space_id) {
|
||||
float score = 0.0;
|
||||
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
|
||||
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
|
||||
score += ext_scorer->beta;
|
||||
prefix->score += score;
|
||||
scores[prefix] += score;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
|
||||
std::sort(state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);
|
||||
using namespace std::placeholders;
|
||||
size_t num_prefixes = std::min(prefixes_copy.size(), beam_size);
|
||||
std::sort(prefixes_copy.begin(), prefixes_copy.begin() + num_prefixes, std::bind(prefix_compare_external, _1, _2, scores));
|
||||
|
||||
//TODO: expose this as an API parameter
|
||||
const int top_paths = 1;
|
||||
|
||||
// compute aproximate ctc score as the return score, without affecting the
|
||||
// return order of decoding result. To delete when decoder gets stable.
|
||||
for (size_t i = 0; i < top_paths && i < state->prefixes.size(); ++i) {
|
||||
double approx_ctc = state->prefixes[i]->score;
|
||||
for (size_t i = 0; i < top_paths && i < prefixes_copy.size(); ++i) {
|
||||
double approx_ctc = scores[prefixes_copy[i]];
|
||||
if (ext_scorer != nullptr) {
|
||||
std::vector<int> output;
|
||||
std::vector<int> timesteps;
|
||||
state->prefixes[i]->get_path_vec(output, timesteps);
|
||||
prefixes_copy[i]->get_path_vec(output, timesteps);
|
||||
auto prefix_length = output.size();
|
||||
auto words = ext_scorer->split_labels(output);
|
||||
// remove word insert
|
||||
@ -195,10 +205,10 @@ std::vector<Output> decoder_decode(DecoderState *state,
|
||||
// remove language model weight:
|
||||
approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
|
||||
}
|
||||
state->prefixes[i]->approx_ctc = approx_ctc;
|
||||
prefixes_copy[i]->approx_ctc = approx_ctc;
|
||||
}
|
||||
|
||||
return get_beam_search_result(state->prefixes, top_paths);
|
||||
return get_beam_search_result(prefixes_copy, top_paths);
|
||||
}
|
||||
|
||||
std::vector<Output> ctc_beam_search_decoder(
|
||||
|
@ -112,6 +112,18 @@ bool prefix_compare(const PathTrie *x, const PathTrie *y) {
|
||||
}
|
||||
}
|
||||
|
||||
bool prefix_compare_external(const PathTrie *x, const PathTrie *y, const std::unordered_map<const PathTrie*, float>& scores) {
|
||||
if (scores.at(x) == scores.at(y)) {
|
||||
if (x->character == y->character) {
|
||||
return false;
|
||||
} else {
|
||||
return (x->character < y->character);
|
||||
}
|
||||
} else {
|
||||
return scores.at(x) > scores.at(y);
|
||||
}
|
||||
}
|
||||
|
||||
void add_word_to_fst(const std::vector<int> &word,
|
||||
fst::StdVectorFst *dictionary) {
|
||||
if (dictionary->NumStates() == 0) {
|
||||
|
@ -67,6 +67,8 @@ std::vector<Output> get_beam_search_result(
|
||||
// Functor for prefix comparsion
|
||||
bool prefix_compare(const PathTrie *x, const PathTrie *y);
|
||||
|
||||
bool prefix_compare_external(const PathTrie *x, const PathTrie *y, const std::unordered_map<const PathTrie*, float>& scores);
|
||||
|
||||
/* Get length of utf8 encoding string
|
||||
* See: http://stackoverflow.com/a/4063229
|
||||
*/
|
||||
|
@ -13,3 +13,5 @@ check_tensorflow_version
|
||||
run_all_inference_tests
|
||||
|
||||
run_multi_inference_tests
|
||||
|
||||
run_cpp_only_inference_tests
|
||||
|
@ -460,6 +460,15 @@ run_multi_inference_tests()
|
||||
assert_correct_multi_ldc93s1 "${multi_phrase_pbmodel_withlm}" "$status"
|
||||
}
|
||||
|
||||
run_cpp_only_inference_tests()
|
||||
{
|
||||
set +e
|
||||
phrase_pbmodel_withlm_intermediate_decode=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1.wav --stream 1280 2>${TASKCLUSTER_TMP_DIR}/stderr | tail -n 1)
|
||||
status=$?
|
||||
set -e
|
||||
assert_correct_ldc93s1_lm "${phrase_pbmodel_withlm_intermediate_decode}" "$status"
|
||||
}
|
||||
|
||||
android_run_tests()
|
||||
{
|
||||
cd ${DS_DSDIR}/native_client/java/
|
||||
|
Loading…
x
Reference in New Issue
Block a user