Use Alphabet to compute string values in get_prev_*
This commit is contained in:
parent
0f71bd2493
commit
5fa1839a7f
@@ -124,7 +124,8 @@ void PathTrie::get_path_vec(std::vector<int>& output, std::vector<int>& timestep
|
|||||||
}
|
}
|
||||||
|
|
||||||
PathTrie* PathTrie::get_prev_grapheme(std::vector<int>& output,
|
PathTrie* PathTrie::get_prev_grapheme(std::vector<int>& output,
|
||||||
std::vector<int>& timesteps)
|
std::vector<int>& timesteps,
|
||||||
|
const Alphabet& alphabet)
|
||||||
{
|
{
|
||||||
PathTrie* stop = this;
|
PathTrie* stop = this;
|
||||||
if (character == ROOT_) {
|
if (character == ROOT_) {
|
||||||
@@ -132,24 +133,23 @@ PathTrie* PathTrie::get_prev_grapheme(std::vector<int>& output,
|
|||||||
}
|
}
|
||||||
// Recursive call: recurse back until stop condition, then append data in
|
// Recursive call: recurse back until stop condition, then append data in
|
||||||
// correct order as we walk back down the stack in the lines below.
|
// correct order as we walk back down the stack in the lines below.
|
||||||
//FIXME: use Alphabet instead of hardcoding +1 here
|
if (!byte_is_codepoint_boundary(alphabet.StringFromLabel(character)[0])) {
|
||||||
if (!byte_is_codepoint_boundary(character + 1)) {
|
stop = parent->get_prev_grapheme(output, timesteps, alphabet);
|
||||||
stop = parent->get_prev_grapheme(output, timesteps);
|
|
||||||
}
|
}
|
||||||
output.push_back(character);
|
output.push_back(character);
|
||||||
timesteps.push_back(timestep);
|
timesteps.push_back(timestep);
|
||||||
return stop;
|
return stop;
|
||||||
}
|
}
|
||||||
|
|
||||||
int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte)
|
int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte,
|
||||||
|
const Alphabet& alphabet)
|
||||||
{
|
{
|
||||||
//FIXME: use Alphabet instead of hardcoding +1 here
|
if (byte_is_codepoint_boundary(alphabet.StringFromLabel(character)[0])) {
|
||||||
if (byte_is_codepoint_boundary(character + 1)) {
|
|
||||||
*first_byte = (unsigned char)character + 1;
|
*first_byte = (unsigned char)character + 1;
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
if (parent != nullptr && parent->character != ROOT_) {
|
if (parent != nullptr && parent->character != ROOT_) {
|
||||||
return 1 + parent->distance_to_codepoint_boundary(first_byte);
|
return 1 + parent->distance_to_codepoint_boundary(first_byte, alphabet);
|
||||||
}
|
}
|
||||||
assert(false); // unreachable
|
assert(false); // unreachable
|
||||||
return 0;
|
return 0;
|
||||||
@@ -157,16 +157,16 @@ int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte)
|
|||||||
|
|
||||||
PathTrie* PathTrie::get_prev_word(std::vector<int>& output,
|
PathTrie* PathTrie::get_prev_word(std::vector<int>& output,
|
||||||
std::vector<int>& timesteps,
|
std::vector<int>& timesteps,
|
||||||
int space_id)
|
const Alphabet& alphabet)
|
||||||
{
|
{
|
||||||
PathTrie* stop = this;
|
PathTrie* stop = this;
|
||||||
if (character == space_id || character == ROOT_) {
|
if (character == alphabet.GetSpaceLabel() || character == ROOT_) {
|
||||||
return stop;
|
return stop;
|
||||||
}
|
}
|
||||||
// Recursive call: recurse back until stop condition, then append data in
|
// Recursive call: recurse back until stop condition, then append data in
|
||||||
// correct order as we walk back down the stack in the lines below.
|
// correct order as we walk back down the stack in the lines below.
|
||||||
if (parent != nullptr) {
|
if (parent != nullptr) {
|
||||||
stop = parent->get_prev_word(output, timesteps, space_id);
|
stop = parent->get_prev_word(output, timesteps, alphabet);
|
||||||
}
|
}
|
||||||
output.push_back(character);
|
output.push_back(character);
|
||||||
timesteps.push_back(timestep);
|
timesteps.push_back(timestep);
|
||||||
|
@@ -8,10 +8,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "fst/fstlib.h"
|
#include "fst/fstlib.h"
|
||||||
|
|
||||||
#ifdef DEBUG
|
|
||||||
#include "alphabet.h"
|
#include "alphabet.h"
|
||||||
#endif
|
|
||||||
|
|
||||||
/* Trie tree for prefix storing and manipulating, with a dictionary in
|
/* Trie tree for prefix storing and manipulating, with a dictionary in
|
||||||
* finite-state transducer for spelling correction.
|
* finite-state transducer for spelling correction.
|
||||||
@@ -31,15 +28,16 @@ public:
|
|||||||
|
|
||||||
// get the prefix data in correct time order from beginning of last grapheme to current node
|
// get the prefix data in correct time order from beginning of last grapheme to current node
|
||||||
PathTrie* get_prev_grapheme(std::vector<int>& output,
|
PathTrie* get_prev_grapheme(std::vector<int>& output,
|
||||||
std::vector<int>& timesteps);
|
std::vector<int>& timesteps,
|
||||||
|
const Alphabet& alphabet);
|
||||||
|
|
||||||
// get the distance from current node to the first codepoint boundary, and the byte value at the boundary
|
// get the distance from current node to the first codepoint boundary, and the byte value at the boundary
|
||||||
int distance_to_codepoint_boundary(unsigned char *first_byte);
|
int distance_to_codepoint_boundary(unsigned char *first_byte, const Alphabet& alphabet);
|
||||||
|
|
||||||
// get the prefix data in correct time order from beginning of last word to current node
|
// get the prefix data in correct time order from beginning of last word to current node
|
||||||
PathTrie* get_prev_word(std::vector<int>& output,
|
PathTrie* get_prev_word(std::vector<int>& output,
|
||||||
std::vector<int>& timesteps,
|
std::vector<int>& timesteps,
|
||||||
int space_id);
|
const Alphabet& alphabet);
|
||||||
|
|
||||||
// update log probs
|
// update log probs
|
||||||
void iterate_to_vec(std::vector<PathTrie*>& output);
|
void iterate_to_vec(std::vector<PathTrie*>& output);
|
||||||
@@ -80,7 +78,6 @@ private:
|
|||||||
// pointer to dictionary of FST
|
// pointer to dictionary of FST
|
||||||
std::shared_ptr<FstType> dictionary_;
|
std::shared_ptr<FstType> dictionary_;
|
||||||
FstType::StateId dictionary_state_;
|
FstType::StateId dictionary_state_;
|
||||||
// true if finding ars in FST
|
|
||||||
std::shared_ptr<fst::SortedMatcher<FstType>> matcher_;
|
std::shared_ptr<fst::SortedMatcher<FstType>> matcher_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -173,7 +173,7 @@ bool Scorer::is_scoring_boundary(PathTrie* prefix, size_t new_label)
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
unsigned char first_byte;
|
unsigned char first_byte;
|
||||||
int distance_to_boundary = prefix->distance_to_codepoint_boundary(&first_byte);
|
int distance_to_boundary = prefix->distance_to_codepoint_boundary(&first_byte, alphabet_);
|
||||||
int needed_bytes;
|
int needed_bytes;
|
||||||
if ((first_byte >> 3) == 0x1E) {
|
if ((first_byte >> 3) == 0x1E) {
|
||||||
needed_bytes = 4;
|
needed_bytes = 4;
|
||||||
@@ -319,9 +319,9 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
|
|||||||
std::vector<int> prefix_steps;
|
std::vector<int> prefix_steps;
|
||||||
|
|
||||||
if (is_utf8_mode_) {
|
if (is_utf8_mode_) {
|
||||||
new_node = current_node->get_prev_grapheme(prefix_vec, prefix_steps);
|
new_node = current_node->get_prev_grapheme(prefix_vec, prefix_steps, alphabet_);
|
||||||
} else {
|
} else {
|
||||||
new_node = current_node->get_prev_word(prefix_vec, prefix_steps, SPACE_ID_);
|
new_node = current_node->get_prev_word(prefix_vec, prefix_steps, alphabet_);
|
||||||
}
|
}
|
||||||
current_node = new_node->parent;
|
current_node = new_node->parent;
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user