Score prefix as soon as a grapheme is formed rather than 1 byte later
This commit is contained in:
parent
f4cdd988df
commit
c1b1a59423
@ -109,11 +109,23 @@ DecoderState::next(const double *probs,
|
||||
log_p = log_prob_c + prefix->score;
|
||||
}
|
||||
|
||||
// skip scoring the space in word based LMs
|
||||
PathTrie* prefix_to_score;
|
||||
if (ext_scorer_->is_utf8_mode()) {
|
||||
prefix_to_score = prefix_new;
|
||||
} else {
|
||||
prefix_to_score = prefix;
|
||||
}
|
||||
|
||||
// check if we need to score
|
||||
bool is_scoring_boundary = ext_scorer_ != nullptr &&
|
||||
ext_scorer_->is_scoring_boundary(prefix_to_score, c);
|
||||
|
||||
// language model scoring
|
||||
if (prefix->character != -1 && ext_scorer_ != nullptr && ext_scorer_->is_scoring_boundary(c)) {
|
||||
if (is_scoring_boundary) {
|
||||
float score = 0.0;
|
||||
std::vector<std::string> ngram;
|
||||
ngram = ext_scorer_->make_ngram(prefix);
|
||||
ngram = ext_scorer_->make_ngram(prefix_to_score);
|
||||
bool bos = ngram.size() < ext_scorer_->get_max_order();
|
||||
score = ext_scorer_->get_log_cond_prob(ngram, bos) * ext_scorer_->alpha;
|
||||
log_p += score;
|
||||
|
@ -133,7 +133,7 @@ PathTrie* PathTrie::get_prev_grapheme(std::vector<int>& output,
|
||||
// Recursive call: recurse back until stop condition, then append data in
|
||||
// correct order as we walk back down the stack in the lines below.
|
||||
//FIXME: use Alphabet instead of hardcoding +1 here
|
||||
if (!byte_is_codepoint_boundary(character+1)) {
|
||||
if (!byte_is_codepoint_boundary(character + 1)) {
|
||||
stop = parent->get_prev_grapheme(output, timesteps);
|
||||
}
|
||||
output.push_back(character);
|
||||
@ -141,6 +141,20 @@ PathTrie* PathTrie::get_prev_grapheme(std::vector<int>& output,
|
||||
return stop;
|
||||
}
|
||||
|
||||
int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte)
|
||||
{
|
||||
//FIXME: use Alphabet instead of hardcoding +1 here
|
||||
if (byte_is_codepoint_boundary(character + 1)) {
|
||||
*first_byte = (unsigned char)character + 1;
|
||||
return 1;
|
||||
}
|
||||
if (parent != nullptr && parent->character != ROOT_) {
|
||||
return 1 + parent->distance_to_codepoint_boundary(first_byte);
|
||||
}
|
||||
assert(false); // unreachable
|
||||
return 0;
|
||||
}
|
||||
|
||||
PathTrie* PathTrie::get_prev_word(std::vector<int>& output,
|
||||
std::vector<int>& timesteps,
|
||||
int space_id)
|
||||
|
@ -33,6 +33,9 @@ public:
|
||||
PathTrie* get_prev_grapheme(std::vector<int>& output,
|
||||
std::vector<int>& timesteps);
|
||||
|
||||
// get the distance from current node to the first codepoint boundary, and the byte value at the boundary
|
||||
int distance_to_codepoint_boundary(unsigned char *first_byte);
|
||||
|
||||
// get the prefix data in correct time order from beginning of last word to current node
|
||||
PathTrie* get_prev_word(std::vector<int>& output,
|
||||
std::vector<int>& timesteps,
|
||||
|
@ -149,13 +149,30 @@ void Scorer::save_dictionary(const std::string& path)
|
||||
dictionary->Write(fout, opt);
|
||||
}
|
||||
|
||||
bool Scorer::is_scoring_boundary(size_t label)
|
||||
bool Scorer::is_scoring_boundary(PathTrie* prefix, size_t new_label)
|
||||
{
|
||||
if (is_utf8_mode()) {
|
||||
unsigned char byte_val = alphabet_.StringFromLabel(label)[0];
|
||||
return byte_is_codepoint_boundary(byte_val);
|
||||
if (prefix->character == -1) {
|
||||
return false;
|
||||
}
|
||||
unsigned char first_byte;
|
||||
int distance_to_boundary = prefix->distance_to_codepoint_boundary(&first_byte);
|
||||
int needed_bytes;
|
||||
if ((first_byte >> 3) == 0x1E) {
|
||||
needed_bytes = 4;
|
||||
} else if ((first_byte >> 4) == 0x0E) {
|
||||
needed_bytes = 3;
|
||||
} else if ((first_byte >> 5) == 0x06) {
|
||||
needed_bytes = 2;
|
||||
} else if ((first_byte >> 7) == 0x00) {
|
||||
needed_bytes = 1;
|
||||
} else {
|
||||
assert(false); // invalid byte sequence. should be unreachable, disallowed by vocabulary/trie
|
||||
return false;
|
||||
}
|
||||
return distance_to_boundary == needed_bytes;
|
||||
} else {
|
||||
return label == SPACE_ID_;
|
||||
return new_label == SPACE_ID_;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -92,8 +92,8 @@ public:
|
||||
// save dictionary in file
|
||||
void save_dictionary(const std::string &path);
|
||||
|
||||
// return whether this label represents a boundary where beam scoring should happen
|
||||
bool is_scoring_boundary(size_t label);
|
||||
// return whether this step represents a boundary where beam scoring should happen
|
||||
bool is_scoring_boundary(PathTrie* prefix, size_t new_label);
|
||||
|
||||
// language model weight
|
||||
double alpha = 0.;
|
||||
|
Loading…
x
Reference in New Issue
Block a user