Address review comment and add missing check for presence of scorer
This commit is contained in:
parent
0e6952c3a8
commit
d2eb305b73
@ -109,6 +109,7 @@ DecoderState::next(const double *probs,
|
|||||||
log_p = log_prob_c + prefix->score;
|
log_p = log_prob_c + prefix->score;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (ext_scorer_ != nullptr) {
|
||||||
// skip scoring the space in word based LMs
|
// skip scoring the space in word based LMs
|
||||||
PathTrie* prefix_to_score;
|
PathTrie* prefix_to_score;
|
||||||
if (ext_scorer_->is_utf8_mode()) {
|
if (ext_scorer_->is_utf8_mode()) {
|
||||||
@ -117,12 +118,8 @@ DecoderState::next(const double *probs,
|
|||||||
prefix_to_score = prefix;
|
prefix_to_score = prefix;
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if we need to score
|
|
||||||
bool is_scoring_boundary = ext_scorer_ != nullptr &&
|
|
||||||
ext_scorer_->is_scoring_boundary(prefix_to_score, c);
|
|
||||||
|
|
||||||
// language model scoring
|
// language model scoring
|
||||||
if (is_scoring_boundary) {
|
if (ext_scorer_->is_scoring_boundary(prefix_to_score, c)) {
|
||||||
float score = 0.0;
|
float score = 0.0;
|
||||||
std::vector<std::string> ngram;
|
std::vector<std::string> ngram;
|
||||||
ngram = ext_scorer_->make_ngram(prefix_to_score);
|
ngram = ext_scorer_->make_ngram(prefix_to_score);
|
||||||
@ -131,6 +128,7 @@ DecoderState::next(const double *probs,
|
|||||||
log_p += score;
|
log_p += score;
|
||||||
log_p += ext_scorer_->beta;
|
log_p += ext_scorer_->beta;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
prefix_new->log_prob_nb_cur =
|
prefix_new->log_prob_nb_cur =
|
||||||
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
|
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
|
||||||
|
|||||||
@ -46,12 +46,6 @@ size_t get_utf8_str_len(const std::string &str) {
|
|||||||
return str_len;
|
return str_len;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return weather a byte is a code point boundary (not a continuation byte).
|
|
||||||
bool byte_is_codepoint_boundary(unsigned char c) {
|
|
||||||
// only continuation bytes have their most significant bits set to 10
|
|
||||||
return (c & 0xC0) != 0x80;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::string> split_into_codepoints(const std::string &str) {
|
std::vector<std::string> split_into_codepoints(const std::string &str) {
|
||||||
std::vector<std::string> result;
|
std::vector<std::string> result;
|
||||||
std::string out_str;
|
std::string out_str;
|
||||||
|
|||||||
@ -89,8 +89,11 @@ std::vector<std::string> split_into_bytes(const std::string &str);
|
|||||||
void add_word_to_fst(const std::vector<int> &word,
|
void add_word_to_fst(const std::vector<int> &word,
|
||||||
fst::StdVectorFst *dictionary);
|
fst::StdVectorFst *dictionary);
|
||||||
|
|
||||||
// Return weather a byte is a code point boundary (not a continuation byte).
|
// Return whether a byte is a code point boundary (not a continuation byte).
|
||||||
bool byte_is_codepoint_boundary(unsigned char c);
|
inline bool byte_is_codepoint_boundary(unsigned char c) {
|
||||||
|
// only continuation bytes have their most significant bits set to 10
|
||||||
|
return (c & 0xC0) != 0x80;
|
||||||
|
}
|
||||||
|
|
||||||
// Add a word in string to dictionary
|
// Add a word in string to dictionary
|
||||||
bool add_word_to_dictionary(
|
bool add_word_to_dictionary(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user