Merge pull request #2194 from mozilla/const-fst

Switch to ConstFst from VectorFst and mmap trie file when reading
This commit is contained in:
Reuben Morais 2019-06-22 09:41:15 -03:00 committed by GitHub
commit 0ea580449c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 31 additions and 23 deletions

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:a5324f06b27c7b4ef88dd2c8bc3d05d4c718db219c2e007f59061a30c9ac7afa oid sha256:1991108e83cf86830cad118ac9e061aadfc89b73a909a66fe1755f846def556e
size 21627983 size 24480560

Binary file not shown.

View File

@ -12,8 +12,6 @@
#include "fst/fstlib.h" #include "fst/fstlib.h"
#include "path_trie.h" #include "path_trie.h"
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
DecoderState* DecoderState*
decoder_init(const Alphabet &alphabet, decoder_init(const Alphabet &alphabet,
int class_dim, int class_dim,
@ -41,7 +39,7 @@ decoder_init(const Alphabet &alphabet,
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
auto dict_ptr = ext_scorer->dictionary->Copy(true); auto dict_ptr = ext_scorer->dictionary->Copy(true);
root->set_dictionary(dict_ptr); root->set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT); auto matcher = std::make_shared<fst::SortedMatcher<PathTrie::FstType>>(*dict_ptr, fst::MATCH_INPUT);
root->set_matcher(matcher); root->set_matcher(matcher);
} }

View File

@ -162,13 +162,12 @@ void PathTrie::remove() {
} }
} }
void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) { void PathTrie::set_dictionary(PathTrie::FstType* dictionary) {
dictionary_ = dictionary; dictionary_ = dictionary;
dictionary_state_ = dictionary->Start(); dictionary_state_ = dictionary->Start();
has_dictionary_ = true; has_dictionary_ = true;
} }
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>; void PathTrie::set_matcher(std::shared_ptr<fst::SortedMatcher<FstType>> matcher) {
void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) {
matcher_ = matcher; matcher_ = matcher;
} }

View File

@ -14,6 +14,8 @@
*/ */
class PathTrie { class PathTrie {
public: public:
using FstType = fst::ConstFst<fst::StdArc>;
PathTrie(); PathTrie();
~PathTrie(); ~PathTrie();
@ -33,9 +35,9 @@ public:
void iterate_to_vec(std::vector<PathTrie*>& output); void iterate_to_vec(std::vector<PathTrie*>& output);
// set dictionary for FST // set dictionary for FST
void set_dictionary(fst::StdVectorFst* dictionary); void set_dictionary(FstType* dictionary);
void set_matcher(std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>>); void set_matcher(std::shared_ptr<fst::SortedMatcher<FstType>>);
bool is_empty() { return ROOT_ == character; } bool is_empty() { return ROOT_ == character; }
@ -61,10 +63,10 @@ private:
std::vector<std::pair<int, PathTrie*>> children_; std::vector<std::pair<int, PathTrie*>> children_;
// pointer to dictionary of FST // pointer to dictionary of FST
fst::StdVectorFst* dictionary_; FstType* dictionary_;
fst::StdVectorFst::StateId dictionary_state_; FstType::StateId dictionary_state_;
// true if finding arcs in FST // true if finding arcs in FST
std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> matcher_; std::shared_ptr<fst::SortedMatcher<FstType>> matcher_;
}; };
#endif // PATH_TRIE_H #endif // PATH_TRIE_H

View File

@ -26,8 +26,8 @@
using namespace lm::ngram; using namespace lm::ngram;
static const int MAGIC = 'TRIE'; static const int32_t MAGIC = 'TRIE';
static const int FILE_VERSION = 3; static const int32_t FILE_VERSION = 4;
Scorer::Scorer(double alpha, Scorer::Scorer(double alpha,
double beta, double beta,
@ -123,7 +123,9 @@ void Scorer::setup(const std::string& lm_path, const std::string& trie_path)
if (!is_character_based_) { if (!is_character_based_) {
fst::FstReadOptions opt; fst::FstReadOptions opt;
dictionary.reset(fst::StdVectorFst::Read(fin, opt)); opt.mode = fst::FstReadOptions::MAP;
opt.source = trie_path;
dictionary.reset(FstType::Read(fin, opt));
} }
} }
@ -138,6 +140,8 @@ void Scorer::save_dictionary(const std::string& path)
fout.write(reinterpret_cast<const char*>(&is_character_based_), sizeof(is_character_based_)); fout.write(reinterpret_cast<const char*>(&is_character_based_), sizeof(is_character_based_));
if (!is_character_based_) { if (!is_character_based_) {
fst::FstWriteOptions opt; fst::FstWriteOptions opt;
opt.align = true;
opt.source = path;
dictionary->Write(fout, opt); dictionary->Write(fout, opt);
} }
} }
@ -248,6 +252,8 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
void Scorer::fill_dictionary(const std::vector<std::string>& vocabulary, bool add_space) void Scorer::fill_dictionary(const std::vector<std::string>& vocabulary, bool add_space)
{ {
// ConstFst is immutable, so we need to use a MutableFst to create the trie,
// and then we convert to a ConstFst for the decoder and for storing on disk.
fst::StdVectorFst dictionary; fst::StdVectorFst dictionary;
// For each unigram convert to ints and put in trie // For each unigram convert to ints and put in trie
for (const auto& word : vocabulary) { for (const auto& word : vocabulary) {
@ -262,18 +268,21 @@ void Scorer::fill_dictionary(const std::vector<std::string>& vocabulary, bool ad
* can greatly increase the size of the FST * can greatly increase the size of the FST
*/ */
fst::RmEpsilon(&dictionary); fst::RmEpsilon(&dictionary);
fst::StdVectorFst* new_dict = new fst::StdVectorFst; std::unique_ptr<fst::StdVectorFst> new_dict(new fst::StdVectorFst);
/* This makes the FST deterministic, meaning for any string input there's /* This makes the FST deterministic, meaning for any string input there's
* only one possible state the FST could be in. It is assumed our * only one possible state the FST could be in. It is assumed our
* dictionary is deterministic when using it. * dictionary is deterministic when using it.
* (otherwise we'd have to check for multiple transitions at each state) * (otherwise we'd have to check for multiple transitions at each state)
*/ */
fst::Determinize(dictionary, new_dict); fst::Determinize(dictionary, new_dict.get());
/* Finds the simplest equivalent fst. This is unnecessary but decreases /* Finds the simplest equivalent fst. This is unnecessary but decreases
* memory usage of the dictionary * memory usage of the dictionary
*/ */
fst::Minimize(new_dict); fst::Minimize(new_dict.get());
this->dictionary.reset(new_dict);
// Now we convert the MutableFst to a ConstFst (Scorer::FstType) via its ctor
std::unique_ptr<FstType> converted(new FstType(*new_dict));
this->dictionary = std::move(converted);
} }

View File

@ -40,6 +40,8 @@ public:
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" }); * scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
*/ */
class Scorer { class Scorer {
using FstType = PathTrie::FstType;
public: public:
Scorer(double alpha, Scorer(double alpha,
double beta, double beta,
@ -82,7 +84,7 @@ public:
double beta; double beta;
// pointer to the dictionary of FST // pointer to the dictionary of FST
std::unique_ptr<fst::StdVectorFst> dictionary; std::unique_ptr<FstType> dictionary;
protected: protected:
// necessary setup: load language model, fill FST's dictionary // necessary setup: load language model, fill FST's dictionary

View File

@ -3,7 +3,6 @@
#include <string> #include <string>
#include "ctcdecode/scorer.h" #include "ctcdecode/scorer.h"
#include "fst/fstlib.h"
#include "alphabet.h" #include "alphabet.h"
using namespace std; using namespace std;

View File

@ -3,7 +3,6 @@
#include <string> #include <string>
#include "ctcdecode/scorer.h" #include "ctcdecode/scorer.h"
#include "fst/fstlib.h"
#include "alphabet.h" #include "alphabet.h"
using namespace std; using namespace std;