Merge pull request #2383 from mozilla/scorer-cleanup
Don't explicitly score the BOS token, and avoid copies when scoring sentences
This commit is contained in:
commit
ba56407376
|
@ -120,7 +120,8 @@ DecoderState::next(const double *probs,
|
|||
float score = 0.0;
|
||||
std::vector<std::string> ngram;
|
||||
ngram = ext_scorer_->make_ngram(prefix_to_score);
|
||||
score = ext_scorer_->get_log_cond_prob(ngram) * ext_scorer_->alpha;
|
||||
bool bos = ngram.size() < ext_scorer_->get_max_order();
|
||||
score = ext_scorer_->get_log_cond_prob(ngram, bos) * ext_scorer_->alpha;
|
||||
log_p += score;
|
||||
log_p += ext_scorer_->beta;
|
||||
}
|
||||
|
|
|
@ -148,54 +148,95 @@ void Scorer::save_dictionary(const std::string& path)
|
|||
}
|
||||
}
|
||||
|
||||
double Scorer::get_log_cond_prob(const std::vector<std::string>& words)
|
||||
double Scorer::get_log_cond_prob(const std::vector<std::string>& words,
|
||||
bool bos,
|
||||
bool eos)
|
||||
{
|
||||
double cond_prob = OOV_SCORE;
|
||||
lm::ngram::State state, tmp_state, out_state;
|
||||
// avoid to inserting <s> in begin
|
||||
language_model_->NullContextWrite(&state);
|
||||
for (size_t i = 0; i < words.size(); ++i) {
|
||||
lm::WordIndex word_index = language_model_->BaseVocabulary().Index(words[i]);
|
||||
return get_log_cond_prob(words.begin(), words.end(), bos, eos);
|
||||
}
|
||||
|
||||
double Scorer::get_log_cond_prob(const std::vector<std::string>::const_iterator& begin,
|
||||
const std::vector<std::string>::const_iterator& end,
|
||||
bool bos,
|
||||
bool eos)
|
||||
{
|
||||
const auto& vocab = language_model_->BaseVocabulary();
|
||||
lm::ngram::State state_vec[2];
|
||||
lm::ngram::State *in_state = &state_vec[0];
|
||||
lm::ngram::State *out_state = &state_vec[1];
|
||||
|
||||
if (bos) {
|
||||
language_model_->BeginSentenceWrite(in_state);
|
||||
} else {
|
||||
language_model_->NullContextWrite(in_state);
|
||||
}
|
||||
|
||||
double cond_prob = 0.0;
|
||||
for (auto it = begin; it != end; ++it) {
|
||||
lm::WordIndex word_index = vocab.Index(*it);
|
||||
|
||||
// encounter OOV
|
||||
if (word_index == 0) {
|
||||
if (word_index == lm::kUNK) {
|
||||
return OOV_SCORE;
|
||||
}
|
||||
cond_prob = language_model_->BaseScore(&state, word_index, &out_state);
|
||||
tmp_state = state;
|
||||
state = out_state;
|
||||
out_state = tmp_state;
|
||||
|
||||
cond_prob = language_model_->BaseScore(in_state, word_index, out_state);
|
||||
std::swap(in_state, out_state);
|
||||
}
|
||||
|
||||
if (eos) {
|
||||
cond_prob = language_model_->BaseScore(in_state, vocab.EndSentence(), out_state);
|
||||
}
|
||||
|
||||
// return loge prob
|
||||
return cond_prob/NUM_FLT_LOGE;
|
||||
}
|
||||
|
||||
double Scorer::get_sent_log_prob(const std::vector<std::string>& words)
|
||||
{
|
||||
std::vector<std::string> sentence;
|
||||
if (words.size() == 0) {
|
||||
for (size_t i = 0; i < max_order_; ++i) {
|
||||
sentence.push_back(START_TOKEN);
|
||||
// For a given sentence (`words`), return sum of LM scores over windows on
|
||||
// sentence. For example, given the sentence:
|
||||
//
|
||||
// there once was an ugly barnacle
|
||||
//
|
||||
// And a language model with max_order_ = 3, this function will return the sum
|
||||
// of the following scores:
|
||||
//
|
||||
// there | <s>
|
||||
// there once | <s>
|
||||
// there once was
|
||||
// once was an
|
||||
// was an ugly
|
||||
// an ugly barnacle
|
||||
// ugly barnacle </s>
|
||||
//
|
||||
// This is used in the decoding process to compute the LM contribution for a
|
||||
// given beam's accumulated score, so that it can be removed and only the
|
||||
// acoustic model contribution can be returned as a confidence score for the
|
||||
// transcription. See DecoderState::decode.
|
||||
const int sent_len = words.size();
|
||||
|
||||
double score = 0.0;
|
||||
for (int win_start = 0, win_end = 1; win_end <= sent_len+1; ++win_end) {
|
||||
const int win_size = win_end - win_start;
|
||||
bool bos = win_size < max_order_;
|
||||
bool eos = win_end == (sent_len + 1);
|
||||
|
||||
// The last window goes one past the end of the words vector as passing the
|
||||
// EOS=true flag counts towards the length of the scored sentence, so we
|
||||
// adjust the win_end index here to not go over bounds.
|
||||
score += get_log_cond_prob(words.begin() + win_start,
|
||||
words.begin() + (eos ? win_end - 1 : win_end),
|
||||
bos,
|
||||
eos);
|
||||
|
||||
// Only increment window start position after we have a full window
|
||||
if (win_size == max_order_) {
|
||||
win_start++;
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < max_order_ - 1; ++i) {
|
||||
sentence.push_back(START_TOKEN);
|
||||
}
|
||||
sentence.insert(sentence.end(), words.begin(), words.end());
|
||||
}
|
||||
sentence.push_back(END_TOKEN);
|
||||
return get_log_prob(sentence);
|
||||
}
|
||||
|
||||
double Scorer::get_log_prob(const std::vector<std::string>& words)
|
||||
{
|
||||
assert(words.size() > max_order_);
|
||||
double score = 0.0;
|
||||
for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) {
|
||||
std::vector<std::string> ngram(words.begin() + i,
|
||||
words.begin() + i + max_order_);
|
||||
score += get_log_cond_prob(ngram);
|
||||
}
|
||||
return score;
|
||||
return score / NUM_FLT_LOGE;
|
||||
}
|
||||
|
||||
void Scorer::reset_params(float alpha, float beta)
|
||||
|
@ -241,10 +282,6 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
|
|||
ngram.push_back(word);
|
||||
|
||||
if (new_node->character == -1) {
|
||||
// No more spaces, but still need order
|
||||
for (int i = 0; i < max_order_ - order - 1; i++) {
|
||||
ngram.push_back(START_TOKEN);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -62,7 +62,14 @@ public:
|
|||
const std::string &trie_path,
|
||||
const std::string &alphabet_config_path);
|
||||
|
||||
double get_log_cond_prob(const std::vector<std::string> &words);
|
||||
double get_log_cond_prob(const std::vector<std::string> &words,
|
||||
bool bos = false,
|
||||
bool eos = false);
|
||||
|
||||
double get_log_cond_prob(const std::vector<std::string>::const_iterator &begin,
|
||||
const std::vector<std::string>::const_iterator &end,
|
||||
bool bos = false,
|
||||
bool eos = false);
|
||||
|
||||
double get_sent_log_prob(const std::vector<std::string> &words);
|
||||
|
||||
|
@ -103,8 +110,6 @@ protected:
|
|||
// fill dictionary for FST
|
||||
void fill_dictionary(const std::vector<std::string> &vocabulary, bool add_space);
|
||||
|
||||
double get_log_prob(const std::vector<std::string> &words);
|
||||
|
||||
private:
|
||||
std::unique_ptr<lm::base::Model> language_model_;
|
||||
bool is_character_based_ = true;
|
||||
|
|
Loading…
Reference in New Issue