Don't explicitly score the BOS token, and avoid copies when scoring sentences

This commit is contained in:
Reuben Morais 2019-09-12 15:15:11 +02:00
parent 005b5a8c3b
commit 6dba6d4a95
3 changed files with 86 additions and 43 deletions

View File

@ -120,7 +120,8 @@ DecoderState::next(const double *probs,
float score = 0.0; float score = 0.0;
std::vector<std::string> ngram; std::vector<std::string> ngram;
ngram = ext_scorer_->make_ngram(prefix_to_score); ngram = ext_scorer_->make_ngram(prefix_to_score);
score = ext_scorer_->get_log_cond_prob(ngram) * ext_scorer_->alpha; bool bos = ngram.size() < ext_scorer_->get_max_order();
score = ext_scorer_->get_log_cond_prob(ngram, bos) * ext_scorer_->alpha;
log_p += score; log_p += score;
log_p += ext_scorer_->beta; log_p += ext_scorer_->beta;
} }

View File

@ -147,54 +147,95 @@ void Scorer::save_dictionary(const std::string& path)
} }
} }
double Scorer::get_log_cond_prob(const std::vector<std::string>& words) double Scorer::get_log_cond_prob(const std::vector<std::string>& words,
bool bos,
bool eos)
{ {
double cond_prob = OOV_SCORE; return get_log_cond_prob(words.begin(), words.end(), bos, eos);
lm::ngram::State state, tmp_state, out_state; }
// avoid to inserting <s> in begin
language_model_->NullContextWrite(&state); double Scorer::get_log_cond_prob(const std::vector<std::string>::const_iterator& begin,
for (size_t i = 0; i < words.size(); ++i) { const std::vector<std::string>::const_iterator& end,
lm::WordIndex word_index = language_model_->BaseVocabulary().Index(words[i]); bool bos,
bool eos)
{
const auto& vocab = language_model_->BaseVocabulary();
lm::ngram::State state_vec[2];
lm::ngram::State *in_state = &state_vec[0];
lm::ngram::State *out_state = &state_vec[1];
if (bos) {
language_model_->BeginSentenceWrite(in_state);
} else {
language_model_->NullContextWrite(in_state);
}
double cond_prob = 0.0;
for (auto it = begin; it != end; ++it) {
lm::WordIndex word_index = vocab.Index(*it);
// encounter OOV // encounter OOV
if (word_index == 0) { if (word_index == lm::kUNK) {
return OOV_SCORE; return OOV_SCORE;
} }
cond_prob = language_model_->BaseScore(&state, word_index, &out_state);
tmp_state = state; cond_prob = language_model_->BaseScore(in_state, word_index, out_state);
state = out_state; std::swap(in_state, out_state);
out_state = tmp_state;
} }
// return loge prob
if (eos) {
cond_prob = language_model_->BaseScore(in_state, vocab.EndSentence(), out_state);
}
// return loge prob
return cond_prob/NUM_FLT_LOGE; return cond_prob/NUM_FLT_LOGE;
} }
double Scorer::get_sent_log_prob(const std::vector<std::string>& words) double Scorer::get_sent_log_prob(const std::vector<std::string>& words)
{ {
std::vector<std::string> sentence; // For a given sentence (`words`), return sum of LM scores over windows on
if (words.size() == 0) { // sentence. For example, given the sentence:
for (size_t i = 0; i < max_order_; ++i) { //
sentence.push_back(START_TOKEN); // there once was an ugly barnacle
} //
} else { // And a language model with max_order_ = 3, this function will return the sum
for (size_t i = 0; i < max_order_ - 1; ++i) { // of the following scores:
sentence.push_back(START_TOKEN); //
} // there | <s>
sentence.insert(sentence.end(), words.begin(), words.end()); // there once | <s>
} // there once was
sentence.push_back(END_TOKEN); // once was an
return get_log_prob(sentence); // was an ugly
} // an ugly barnacle
// ugly barnacle </s>
//
// This is used in the decoding process to compute the LM contribution for a
// given beam's accumulated score, so that it can be removed and only the
// acoustic model contribution can be returned as a confidence score for the
// transcription. See DecoderState::decode.
const int sent_len = words.size();
double Scorer::get_log_prob(const std::vector<std::string>& words)
{
assert(words.size() > max_order_);
double score = 0.0; double score = 0.0;
for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) { for (int win_start = 0, win_end = 1; win_end <= sent_len+1; ++win_end) {
std::vector<std::string> ngram(words.begin() + i, const int win_size = win_end - win_start;
words.begin() + i + max_order_); bool bos = win_size < max_order_;
score += get_log_cond_prob(ngram); bool eos = win_end == sent_len + 1;
// The last window goes one past the end of the words vector as passing the
// EOS=true flag counts towards the length of the scored sentence, so we
// adjust the win_end index here to not go over bounds.
score += get_log_cond_prob(words.begin() + win_start,
words.begin() + (eos ? win_end - 1 : win_end),
bos,
eos);
// Only increment window start position after we have a full window
if (win_size == max_order_) {
win_start++;
}
} }
return score;
return score / NUM_FLT_LOGE;
} }
void Scorer::reset_params(float alpha, float beta) void Scorer::reset_params(float alpha, float beta)
@ -240,10 +281,6 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
ngram.push_back(word); ngram.push_back(word);
if (new_node->character == -1) { if (new_node->character == -1) {
// No more spaces, but still need order
for (int i = 0; i < max_order_ - order - 1; i++) {
ngram.push_back(START_TOKEN);
}
break; break;
} }
} }

View File

@ -62,7 +62,14 @@ public:
const std::string &trie_path, const std::string &trie_path,
const std::string &alphabet_config_path); const std::string &alphabet_config_path);
double get_log_cond_prob(const std::vector<std::string> &words); double get_log_cond_prob(const std::vector<std::string> &words,
bool bos = false,
bool eos = false);
double get_log_cond_prob(const std::vector<std::string>::const_iterator &begin,
const std::vector<std::string>::const_iterator &end,
bool bos = false,
bool eos = false);
double get_sent_log_prob(const std::vector<std::string> &words); double get_sent_log_prob(const std::vector<std::string> &words);
@ -103,8 +110,6 @@ protected:
// fill dictionary for FST // fill dictionary for FST
void fill_dictionary(const std::vector<std::string> &vocabulary, bool add_space); void fill_dictionary(const std::vector<std::string> &vocabulary, bool add_space);
double get_log_prob(const std::vector<std::string> &words);
private: private:
std::unique_ptr<lm::base::Model> language_model_; std::unique_ptr<lm::base::Model> language_model_;
bool is_character_based_ = true; bool is_character_based_ = true;