Use Alphabet to compute string values in get_prev_*

This commit is contained in:
Reuben Morais 2020-04-13 10:13:49 +02:00
parent 0f71bd2493
commit 5fa1839a7f
3 changed files with 18 additions and 21 deletions

View File

@ -124,7 +124,8 @@ void PathTrie::get_path_vec(std::vector<int>& output, std::vector<int>& timestep
}
PathTrie* PathTrie::get_prev_grapheme(std::vector<int>& output,
std::vector<int>& timesteps)
std::vector<int>& timesteps,
const Alphabet& alphabet)
{
PathTrie* stop = this;
if (character == ROOT_) {
@ -132,24 +133,23 @@ PathTrie* PathTrie::get_prev_grapheme(std::vector<int>& output,
}
// Recursive call: recurse back until stop condition, then append data in
// correct order as we walk back down the stack in the lines below.
//FIXME: use Alphabet instead of hardcoding +1 here
if (!byte_is_codepoint_boundary(character + 1)) {
stop = parent->get_prev_grapheme(output, timesteps);
if (!byte_is_codepoint_boundary(alphabet.StringFromLabel(character)[0])) {
stop = parent->get_prev_grapheme(output, timesteps, alphabet);
}
output.push_back(character);
timesteps.push_back(timestep);
return stop;
}
int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte)
int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte,
const Alphabet& alphabet)
{
//FIXME: use Alphabet instead of hardcoding +1 here
if (byte_is_codepoint_boundary(character + 1)) {
if (byte_is_codepoint_boundary(alphabet.StringFromLabel(character)[0])) {
*first_byte = (unsigned char)character + 1;
return 1;
}
if (parent != nullptr && parent->character != ROOT_) {
return 1 + parent->distance_to_codepoint_boundary(first_byte);
return 1 + parent->distance_to_codepoint_boundary(first_byte, alphabet);
}
assert(false); // unreachable
return 0;
@ -157,16 +157,16 @@ int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte)
PathTrie* PathTrie::get_prev_word(std::vector<int>& output,
std::vector<int>& timesteps,
int space_id)
const Alphabet& alphabet)
{
PathTrie* stop = this;
if (character == space_id || character == ROOT_) {
if (character == alphabet.GetSpaceLabel() || character == ROOT_) {
return stop;
}
// Recursive call: recurse back until stop condition, then append data in
// correct order as we walk back down the stack in the lines below.
if (parent != nullptr) {
stop = parent->get_prev_word(output, timesteps, space_id);
stop = parent->get_prev_word(output, timesteps, alphabet);
}
output.push_back(character);
timesteps.push_back(timestep);

View File

@ -8,10 +8,7 @@
#include <vector>
#include "fst/fstlib.h"
#ifdef DEBUG
#include "alphabet.h"
#endif
/* Trie tree for prefix storing and manipulating, with a dictionary in
* finite-state transducer for spelling correction.
@ -31,15 +28,16 @@ public:
// get the prefix data in correct time order from beginning of last grapheme to current node
PathTrie* get_prev_grapheme(std::vector<int>& output,
std::vector<int>& timesteps);
std::vector<int>& timesteps,
const Alphabet& alphabet);
// get the distance from current node to the first codepoint boundary, and the byte value at the boundary
int distance_to_codepoint_boundary(unsigned char *first_byte);
int distance_to_codepoint_boundary(unsigned char *first_byte, const Alphabet& alphabet);
// get the prefix data in correct time order from beginning of last word to current node
PathTrie* get_prev_word(std::vector<int>& output,
std::vector<int>& timesteps,
int space_id);
const Alphabet& alphabet);
// update log probs
void iterate_to_vec(std::vector<PathTrie*>& output);
@ -80,7 +78,6 @@ private:
// pointer to dictionary of FST
std::shared_ptr<FstType> dictionary_;
FstType::StateId dictionary_state_;
// true if finding ars in FST
std::shared_ptr<fst::SortedMatcher<FstType>> matcher_;
};

View File

@ -173,7 +173,7 @@ bool Scorer::is_scoring_boundary(PathTrie* prefix, size_t new_label)
return false;
}
unsigned char first_byte;
int distance_to_boundary = prefix->distance_to_codepoint_boundary(&first_byte);
int distance_to_boundary = prefix->distance_to_codepoint_boundary(&first_byte, alphabet_);
int needed_bytes;
if ((first_byte >> 3) == 0x1E) {
needed_bytes = 4;
@ -319,9 +319,9 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
std::vector<int> prefix_steps;
if (is_utf8_mode_) {
new_node = current_node->get_prev_grapheme(prefix_vec, prefix_steps);
new_node = current_node->get_prev_grapheme(prefix_vec, prefix_steps, alphabet_);
} else {
new_node = current_node->get_prev_word(prefix_vec, prefix_steps, SPACE_ID_);
new_node = current_node->get_prev_word(prefix_vec, prefix_steps, alphabet_);
}
current_node = new_node->parent;