diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.cpp b/native_client/ctcdecode/ctc_beam_search_decoder.cpp
index c33747f9..f6ec3082 100644
--- a/native_client/ctcdecode/ctc_beam_search_decoder.cpp
+++ b/native_client/ctcdecode/ctc_beam_search_decoder.cpp
@@ -120,7 +120,8 @@ DecoderState::next(const double *probs,
           float score = 0.0;
           std::vector<std::string> ngram;
           ngram = ext_scorer_->make_ngram(prefix_to_score);
-          score = ext_scorer_->get_log_cond_prob(ngram) * ext_scorer_->alpha;
+          bool bos = ngram.size() < ext_scorer_->get_max_order();
+          score = ext_scorer_->get_log_cond_prob(ngram, bos) * ext_scorer_->alpha;
           log_p += score;
           log_p += ext_scorer_->beta;
         }
diff --git a/native_client/ctcdecode/scorer.cpp b/native_client/ctcdecode/scorer.cpp
index 55717d60..ae0b98c0 100644
--- a/native_client/ctcdecode/scorer.cpp
+++ b/native_client/ctcdecode/scorer.cpp
@@ -147,54 +147,95 @@ void Scorer::save_dictionary(const std::string& path)
   }
 }
 
-double Scorer::get_log_cond_prob(const std::vector<std::string>& words)
+double Scorer::get_log_cond_prob(const std::vector<std::string>& words,
+                                 bool bos,
+                                 bool eos)
 {
-  double cond_prob = OOV_SCORE;
-  lm::ngram::State state, tmp_state, out_state;
-  // avoid to inserting <s> in begin
-  language_model_->NullContextWrite(&state);
-  for (size_t i = 0; i < words.size(); ++i) {
-    lm::WordIndex word_index = language_model_->BaseVocabulary().Index(words[i]);
+  return get_log_cond_prob(words.begin(), words.end(), bos, eos);
+}
+
+double Scorer::get_log_cond_prob(const std::vector<std::string>::const_iterator& begin,
+                                 const std::vector<std::string>::const_iterator& end,
+                                 bool bos,
+                                 bool eos)
+{
+  const auto& vocab = language_model_->BaseVocabulary();
+  lm::ngram::State state_vec[2];
+  lm::ngram::State *in_state = &state_vec[0];
+  lm::ngram::State *out_state = &state_vec[1];
+
+  if (bos) {
+    language_model_->BeginSentenceWrite(in_state);
+  } else {
+    language_model_->NullContextWrite(in_state);
+  }
+
+  double cond_prob = 0.0;
+  for (auto it = begin; it != end; ++it) {
+    lm::WordIndex word_index = vocab.Index(*it);
+
     // encounter OOV
-    if (word_index == 0) {
+    if (word_index == lm::kUNK) {
       return OOV_SCORE;
     }
-    cond_prob = language_model_->BaseScore(&state, word_index, &out_state);
-    tmp_state = state;
-    state = out_state;
-    out_state = tmp_state;
+
+    cond_prob = language_model_->BaseScore(in_state, word_index, out_state);
+    std::swap(in_state, out_state);
   }
-  // return loge prob
+
+  if (eos) {
+    cond_prob = language_model_->BaseScore(in_state, vocab.EndSentence(), out_state);
+  }
+
+  // return loge prob
   return cond_prob/NUM_FLT_LOGE;
 }
 
 double Scorer::get_sent_log_prob(const std::vector<std::string>& words)
 {
-  std::vector<std::string> sentence;
-  if (words.size() == 0) {
-    for (size_t i = 0; i < max_order_; ++i) {
-      sentence.push_back(START_TOKEN);
-    }
-  } else {
-    for (size_t i = 0; i < max_order_ - 1; ++i) {
-      sentence.push_back(START_TOKEN);
-    }
-    sentence.insert(sentence.end(), words.begin(), words.end());
-  }
-  sentence.push_back(END_TOKEN);
-  return get_log_prob(sentence);
-}
+  // For a given sentence (`words`), return the sum of LM scores over windows
+  // on the sentence. For example, given the sentence:
+  //
+  //     there once was an ugly barnacle
+  //
+  // and a language model with max_order_ = 3, this function will return the
+  // sum of the following scores:
+  //
+  //     there                            | <s>
+  //     there once                       | <s>
+  //     there once was
+  //           once was an
+  //                was an ugly
+  //                    an ugly barnacle
+  //                       ugly barnacle </s>
+  //
+  // This is used in the decoding process to compute the LM contribution for a
+  // given beam's accumulated score, so that it can be removed and only the
+  // acoustic model contribution can be returned as a confidence score for the
+  // transcription. See DecoderState::decode.
+  const int sent_len = words.size();
 
-double Scorer::get_log_prob(const std::vector<std::string>& words)
-{
-  assert(words.size() > max_order_);
   double score = 0.0;
-  for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) {
-    std::vector<std::string> ngram(words.begin() + i,
-                                   words.begin() + i + max_order_);
-    score += get_log_cond_prob(ngram);
+  for (int win_start = 0, win_end = 1; win_end <= sent_len+1; ++win_end) {
+    const int win_size = win_end - win_start;
+    bool bos = win_size < max_order_;
+    bool eos = win_end == sent_len + 1;
+
+    // The last window goes one past the end of the words vector, as passing
+    // the eos=true flag counts towards the length of the scored sentence, so
+    // we adjust the win_end index here to avoid going out of bounds.
+    score += get_log_cond_prob(words.begin() + win_start,
+                               words.begin() + (eos ? win_end - 1 : win_end),
+                               bos,
+                               eos);
+
+    // Only increment the window start position once we have a full window.
+    if (win_size == max_order_) {
+      win_start++;
+    }
   }
-  return score;
+
+  return score / NUM_FLT_LOGE;
 }
 
 void Scorer::reset_params(float alpha, float beta)
@@ -240,10 +281,6 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
     ngram.push_back(word);
 
     if (new_node->character == -1) {
-      // No more spaces, but still need order
-      for (int i = 0; i < max_order_ - order - 1; i++) {
-        ngram.push_back(START_TOKEN);
-      }
       break;
     }
   }
diff --git a/native_client/ctcdecode/scorer.h b/native_client/ctcdecode/scorer.h
index 28aeb65e..5540138c 100644
--- a/native_client/ctcdecode/scorer.h
+++ b/native_client/ctcdecode/scorer.h
@@ -62,7 +62,14 @@ public:
               const std::string &trie_path,
               const std::string &alphabet_config_path);
 
-  double get_log_cond_prob(const std::vector<std::string> &words);
+  double get_log_cond_prob(const std::vector<std::string> &words,
+                           bool bos = false,
+                           bool eos = false);
+
+  double get_log_cond_prob(const std::vector<std::string>::const_iterator &begin,
+                           const std::vector<std::string>::const_iterator &end,
+                           bool bos = false,
+                           bool eos = false);
 
   double get_sent_log_prob(const std::vector<std::string> &words);
 
@@ -103,8 +110,6 @@ protected:
   // fill dictionary for FST
   void fill_dictionary(const std::vector<std::string> &vocabulary, bool add_space);
 
-  double get_log_prob(const std::vector<std::string> &words);
-
 private:
  std::unique_ptr<lm::base::Model> language_model_;
  bool is_character_based_ = true;
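The rewritten get_log_cond_prob replaces the old tmp_state copy shuffle with two preallocated states and a pointer std::swap: each BaseScore call reads from in_state and writes to out_state, and the swap makes the freshly written state the input for the next word. (The kUNK comparison is equivalent to the old `== 0` check, since KenLM defines lm::kUNK as word index 0, but states the intent.) Below is a minimal standalone sketch of the ping-pong pattern; FakeState and fake_score are hypothetical stand-ins for KenLM's lm::ngram::State and Model::BaseScore, not part of the real Scorer.

// Sketch only: FakeState/fake_score stand in for lm::ngram::State and
// lm::base::Model::BaseScore; they are not part of the actual code.
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct FakeState {
  std::vector<std::string> history;  // stands in for KenLM's opaque state
};

// Stand-in for Model::BaseScore: reads the input state, writes the output
// state, and returns a dummy conditional log probability.
double fake_score(const FakeState *in, const std::string &word, FakeState *out) {
  out->history = in->history;
  out->history.push_back(word);
  return -1.0;
}

double score_words(const std::vector<std::string> &words) {
  // Two states allocated up front; no per-word State copies.
  FakeState state_vec[2];
  FakeState *in_state = &state_vec[0];
  FakeState *out_state = &state_vec[1];

  double cond_prob = 0.0;
  for (const auto &w : words) {
    cond_prob = fake_score(in_state, w, out_state);
    // The freshly written state becomes the next call's input; this replaces
    // the old tmp_state = state; state = out_state; out_state = tmp_state;
    std::swap(in_state, out_state);
  }
  return cond_prob;
}

int main() {
  std::cout << score_words({"there", "once", "was"}) << "\n";
  return 0;
}

The swap also feeds the new eos branch: after the loop, in_state points at the state produced by the last word, so scoring vocab.EndSentence() continues from the full context rather than a stale state.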
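The windowing loop in the new get_sent_log_prob can be sanity-checked in isolation. The sketch below repeats the loop against a stub scorer that prints each window with its bos/eos decoration instead of querying KenLM; stub_log_cond_prob and max_order are illustrative names, not part of the Scorer API. For the six-word example sentence it prints exactly the seven windows listed in the scorer.cpp comment.

// Sketch only: stub_log_cond_prob stands in for Scorer::get_log_cond_prob
// and max_order for max_order_; neither exists in the real decoder.
#include <iostream>
#include <string>
#include <vector>

// Print the window a real scorer would hand to KenLM, with <s>/</s> shown
// for the bos/eos flags, and return a dummy log probability.
double stub_log_cond_prob(std::vector<std::string>::const_iterator begin,
                          std::vector<std::string>::const_iterator end,
                          bool bos, bool eos) {
  if (bos) std::cout << "<s> ";
  for (auto it = begin; it != end; ++it) std::cout << *it << " ";
  if (eos) std::cout << "</s>";
  std::cout << "\n";
  return -1.0;
}

int main() {
  const std::vector<std::string> words = {"there", "once", "was",
                                          "an",    "ugly", "barnacle"};
  const int max_order = 3;  // trigram LM, as in the scorer.cpp comment
  const int sent_len = words.size();

  // Same loop as the new Scorer::get_sent_log_prob.
  double score = 0.0;
  for (int win_start = 0, win_end = 1; win_end <= sent_len + 1; ++win_end) {
    const int win_size = win_end - win_start;
    bool bos = win_size < max_order;     // short window: score with <s> context
    bool eos = win_end == sent_len + 1;  // final window: score </s>

    score += stub_log_cond_prob(words.begin() + win_start,
                                words.begin() + (eos ? win_end - 1 : win_end),
                                bos, eos);

    // Slide the window start only once the window is full.
    if (win_size == max_order) {
      win_start++;
    }
  }
  std::cout << "stub total: " << score << "\n";
  return 0;
}

Running it prints "<s> there", "<s> there once", "there once was", "once was an", "was an ugly", "an ugly barnacle", "ugly barnacle </s>", confirming that bos is set only while the window is still shorter than the LM order and that the final iteration scores </s> without reading past the end of words.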