Don't explicitly score the BOS token, and avoid copies when scoring sentences

This commit is contained in:
Reuben Morais 2019-09-12 15:15:11 +02:00
parent 005b5a8c3b
commit 6dba6d4a95
3 changed files with 86 additions and 43 deletions

View File

@ -120,7 +120,8 @@ DecoderState::next(const double *probs,
float score = 0.0;
std::vector<std::string> ngram;
ngram = ext_scorer_->make_ngram(prefix_to_score);
score = ext_scorer_->get_log_cond_prob(ngram) * ext_scorer_->alpha;
bool bos = ngram.size() < ext_scorer_->get_max_order();
score = ext_scorer_->get_log_cond_prob(ngram, bos) * ext_scorer_->alpha;
log_p += score;
log_p += ext_scorer_->beta;
}

View File

@ -147,54 +147,95 @@ void Scorer::save_dictionary(const std::string& path)
}
}
double Scorer::get_log_cond_prob(const std::vector<std::string>& words)
double Scorer::get_log_cond_prob(const std::vector<std::string>& words,
bool bos,
bool eos)
{
double cond_prob = OOV_SCORE;
lm::ngram::State state, tmp_state, out_state;
// avoid to inserting <s> in begin
language_model_->NullContextWrite(&state);
for (size_t i = 0; i < words.size(); ++i) {
lm::WordIndex word_index = language_model_->BaseVocabulary().Index(words[i]);
return get_log_cond_prob(words.begin(), words.end(), bos, eos);
}
double Scorer::get_log_cond_prob(const std::vector<std::string>::const_iterator& begin,
const std::vector<std::string>::const_iterator& end,
bool bos,
bool eos)
{
const auto& vocab = language_model_->BaseVocabulary();
lm::ngram::State state_vec[2];
lm::ngram::State *in_state = &state_vec[0];
lm::ngram::State *out_state = &state_vec[1];
if (bos) {
language_model_->BeginSentenceWrite(in_state);
} else {
language_model_->NullContextWrite(in_state);
}
double cond_prob = 0.0;
for (auto it = begin; it != end; ++it) {
lm::WordIndex word_index = vocab.Index(*it);
// encounter OOV
if (word_index == 0) {
if (word_index == lm::kUNK) {
return OOV_SCORE;
}
cond_prob = language_model_->BaseScore(&state, word_index, &out_state);
tmp_state = state;
state = out_state;
out_state = tmp_state;
cond_prob = language_model_->BaseScore(in_state, word_index, out_state);
std::swap(in_state, out_state);
}
// return loge prob
if (eos) {
cond_prob = language_model_->BaseScore(in_state, vocab.EndSentence(), out_state);
}
// return loge prob
return cond_prob/NUM_FLT_LOGE;
}
double Scorer::get_sent_log_prob(const std::vector<std::string>& words)
{
std::vector<std::string> sentence;
if (words.size() == 0) {
for (size_t i = 0; i < max_order_; ++i) {
sentence.push_back(START_TOKEN);
}
} else {
for (size_t i = 0; i < max_order_ - 1; ++i) {
sentence.push_back(START_TOKEN);
}
sentence.insert(sentence.end(), words.begin(), words.end());
}
sentence.push_back(END_TOKEN);
return get_log_prob(sentence);
}
// For a given sentence (`words`), return sum of LM scores over windows on
// sentence. For example, given the sentence:
//
// there once was an ugly barnacle
//
// And a language model with max_order_ = 3, this function will return the sum
// of the following scores:
//
// there | <s>
// there once | <s>
// there once was
// once was an
// was an ugly
// an ugly barnacle
// ugly barnacle </s>
//
// This is used in the decoding process to compute the LM contribution for a
// given beam's accumulated score, so that it can be removed and only the
// acoustic model contribution can be returned as a confidence score for the
// transcription. See DecoderState::decode.
const int sent_len = words.size();
double Scorer::get_log_prob(const std::vector<std::string>& words)
{
assert(words.size() > max_order_);
double score = 0.0;
for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) {
std::vector<std::string> ngram(words.begin() + i,
words.begin() + i + max_order_);
score += get_log_cond_prob(ngram);
for (int win_start = 0, win_end = 1; win_end <= sent_len+1; ++win_end) {
const int win_size = win_end - win_start;
bool bos = win_size < max_order_;
bool eos = win_end == sent_len + 1;
// The last window goes one past the end of the words vector as passing the
// EOS=true flag counts towards the length of the scored sentence, so we
// adjust the win_end index here to not go over bounds.
score += get_log_cond_prob(words.begin() + win_start,
words.begin() + (eos ? win_end - 1 : win_end),
bos,
eos);
// Only increment window start position after we have a full window
if (win_size == max_order_) {
win_start++;
}
}
return score;
return score / NUM_FLT_LOGE;
}
void Scorer::reset_params(float alpha, float beta)
@ -240,10 +281,6 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
ngram.push_back(word);
if (new_node->character == -1) {
// No more spaces, but still need order
for (int i = 0; i < max_order_ - order - 1; i++) {
ngram.push_back(START_TOKEN);
}
break;
}
}

View File

@ -62,7 +62,14 @@ public:
const std::string &trie_path,
const std::string &alphabet_config_path);
double get_log_cond_prob(const std::vector<std::string> &words);
double get_log_cond_prob(const std::vector<std::string> &words,
bool bos = false,
bool eos = false);
double get_log_cond_prob(const std::vector<std::string>::const_iterator &begin,
const std::vector<std::string>::const_iterator &end,
bool bos = false,
bool eos = false);
double get_sent_log_prob(const std::vector<std::string> &words);
@ -103,8 +110,6 @@ protected:
// fill dictionary for FST
void fill_dictionary(const std::vector<std::string> &vocabulary, bool add_space);
double get_log_prob(const std::vector<std::string> &words);
private:
std::unique_ptr<lm::base::Model> language_model_;
bool is_character_based_ = true;