Score prefix as soon as a grapheme is formed rather than 1 byte later

This commit is contained in:
Reuben Morais 2019-11-08 16:17:28 +01:00
parent f4cdd988df
commit c1b1a59423
5 changed files with 55 additions and 9 deletions

View File

@ -109,11 +109,23 @@ DecoderState::next(const double *probs,
log_p = log_prob_c + prefix->score;
}
// skip scoring the space in word based LMs
PathTrie* prefix_to_score;
if (ext_scorer_->is_utf8_mode()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
}
// check if we need to score
bool is_scoring_boundary = ext_scorer_ != nullptr &&
ext_scorer_->is_scoring_boundary(prefix_to_score, c);
// language model scoring
if (prefix->character != -1 && ext_scorer_ != nullptr && ext_scorer_->is_scoring_boundary(c)) {
if (is_scoring_boundary) {
float score = 0.0;
std::vector<std::string> ngram;
ngram = ext_scorer_->make_ngram(prefix);
ngram = ext_scorer_->make_ngram(prefix_to_score);
bool bos = ngram.size() < ext_scorer_->get_max_order();
score = ext_scorer_->get_log_cond_prob(ngram, bos) * ext_scorer_->alpha;
log_p += score;

View File

@ -133,7 +133,7 @@ PathTrie* PathTrie::get_prev_grapheme(std::vector<int>& output,
// Recursive call: recurse back until stop condition, then append data in
// correct order as we walk back down the stack in the lines below.
//FIXME: use Alphabet instead of hardcoding +1 here
if (!byte_is_codepoint_boundary(character+1)) {
if (!byte_is_codepoint_boundary(character + 1)) {
stop = parent->get_prev_grapheme(output, timesteps);
}
output.push_back(character);
@ -141,6 +141,20 @@ PathTrie* PathTrie::get_prev_grapheme(std::vector<int>& output,
return stop;
}
int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte)
{
//FIXME: use Alphabet instead of hardcoding +1 here
if (byte_is_codepoint_boundary(character + 1)) {
*first_byte = (unsigned char)character + 1;
return 1;
}
if (parent != nullptr && parent->character != ROOT_) {
return 1 + parent->distance_to_codepoint_boundary(first_byte);
}
assert(false); // unreachable
return 0;
}
PathTrie* PathTrie::get_prev_word(std::vector<int>& output,
std::vector<int>& timesteps,
int space_id)

View File

@ -33,6 +33,9 @@ public:
PathTrie* get_prev_grapheme(std::vector<int>& output,
std::vector<int>& timesteps);
// get the distance from current node to the first codepoint boundary, and the byte value at the boundary
int distance_to_codepoint_boundary(unsigned char *first_byte);
// get the prefix data in correct time order from beginning of last word to current node
PathTrie* get_prev_word(std::vector<int>& output,
std::vector<int>& timesteps,

View File

@ -149,13 +149,30 @@ void Scorer::save_dictionary(const std::string& path)
dictionary->Write(fout, opt);
}
bool Scorer::is_scoring_boundary(size_t label)
bool Scorer::is_scoring_boundary(PathTrie* prefix, size_t new_label)
{
if (is_utf8_mode()) {
unsigned char byte_val = alphabet_.StringFromLabel(label)[0];
return byte_is_codepoint_boundary(byte_val);
if (prefix->character == -1) {
return false;
}
unsigned char first_byte;
int distance_to_boundary = prefix->distance_to_codepoint_boundary(&first_byte);
int needed_bytes;
if ((first_byte >> 3) == 0x1E) {
needed_bytes = 4;
} else if ((first_byte >> 4) == 0x0E) {
needed_bytes = 3;
} else if ((first_byte >> 5) == 0x06) {
needed_bytes = 2;
} else if ((first_byte >> 7) == 0x00) {
needed_bytes = 1;
} else {
assert(false); // invalid byte sequence. should be unreachable, disallowed by vocabulary/trie
return false;
}
return distance_to_boundary == needed_bytes;
} else {
return label == SPACE_ID_;
return new_label == SPACE_ID_;
}
}

View File

@ -92,8 +92,8 @@ public:
// save dictionary in file
void save_dictionary(const std::string &path);
// return whether this label represents a boundary where beam scoring should happen
bool is_scoring_boundary(size_t label);
// return whether this step represents a boundary where beam scoring should happen
bool is_scoring_boundary(PathTrie* prefix, size_t new_label);
// language model weight
double alpha = 0.;