Merge pull request #2194 from mozilla/const-fst

Switch to ConstFst from VectorFst and mmap trie file when reading
This commit is contained in:
Reuben Morais 2019-06-22 09:41:15 -03:00 committed by GitHub
commit 0ea580449c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 31 additions and 23 deletions

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a5324f06b27c7b4ef88dd2c8bc3d05d4c718db219c2e007f59061a30c9ac7afa
size 21627983
oid sha256:1991108e83cf86830cad118ac9e061aadfc89b73a909a66fe1755f846def556e
size 24480560

Binary file not shown.

View File

@ -12,8 +12,6 @@
#include "fst/fstlib.h"
#include "path_trie.h"
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
DecoderState*
decoder_init(const Alphabet &alphabet,
int class_dim,
@ -41,7 +39,7 @@ decoder_init(const Alphabet &alphabet,
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
auto dict_ptr = ext_scorer->dictionary->Copy(true);
root->set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
auto matcher = std::make_shared<fst::SortedMatcher<PathTrie::FstType>>(*dict_ptr, fst::MATCH_INPUT);
root->set_matcher(matcher);
}

View File

@ -162,13 +162,12 @@ void PathTrie::remove() {
}
}
void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) {
void PathTrie::set_dictionary(PathTrie::FstType* dictionary) {
dictionary_ = dictionary;
dictionary_state_ = dictionary->Start();
has_dictionary_ = true;
}
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) {
void PathTrie::set_matcher(std::shared_ptr<fst::SortedMatcher<FstType>> matcher) {
matcher_ = matcher;
}

View File

@ -14,6 +14,8 @@
*/
class PathTrie {
public:
using FstType = fst::ConstFst<fst::StdArc>;
PathTrie();
~PathTrie();
@ -33,9 +35,9 @@ public:
void iterate_to_vec(std::vector<PathTrie*>& output);
// set dictionary for FST
void set_dictionary(fst::StdVectorFst* dictionary);
void set_dictionary(FstType* dictionary);
void set_matcher(std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>>);
void set_matcher(std::shared_ptr<fst::SortedMatcher<FstType>>);
bool is_empty() { return ROOT_ == character; }
@ -61,10 +63,10 @@ private:
std::vector<std::pair<int, PathTrie*>> children_;
// pointer to dictionary of FST
fst::StdVectorFst* dictionary_;
fst::StdVectorFst::StateId dictionary_state_;
FstType* dictionary_;
FstType::StateId dictionary_state_;
// matcher used to find arcs in the FST
std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> matcher_;
std::shared_ptr<fst::SortedMatcher<FstType>> matcher_;
};
#endif // PATH_TRIE_H

View File

@ -26,8 +26,8 @@
using namespace lm::ngram;
static const int MAGIC = 'TRIE';
static const int FILE_VERSION = 3;
static const int32_t MAGIC = 'TRIE';
static const int32_t FILE_VERSION = 4;
Scorer::Scorer(double alpha,
double beta,
@ -123,7 +123,9 @@ void Scorer::setup(const std::string& lm_path, const std::string& trie_path)
if (!is_character_based_) {
fst::FstReadOptions opt;
dictionary.reset(fst::StdVectorFst::Read(fin, opt));
opt.mode = fst::FstReadOptions::MAP;
opt.source = trie_path;
dictionary.reset(FstType::Read(fin, opt));
}
}
@ -138,6 +140,8 @@ void Scorer::save_dictionary(const std::string& path)
fout.write(reinterpret_cast<const char*>(&is_character_based_), sizeof(is_character_based_));
if (!is_character_based_) {
fst::FstWriteOptions opt;
opt.align = true;
opt.source = path;
dictionary->Write(fout, opt);
}
}
@ -248,6 +252,8 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
void Scorer::fill_dictionary(const std::vector<std::string>& vocabulary, bool add_space)
{
// ConstFst is immutable, so we need to use a MutableFst to create the trie,
// and then we convert to a ConstFst for the decoder and for storing on disk.
fst::StdVectorFst dictionary;
// For each unigram convert to ints and put in trie
for (const auto& word : vocabulary) {
@ -262,18 +268,21 @@ void Scorer::fill_dictionary(const std::vector<std::string>& vocabulary, bool ad
* can greatly increase the size of the FST
*/
fst::RmEpsilon(&dictionary);
fst::StdVectorFst* new_dict = new fst::StdVectorFst;
std::unique_ptr<fst::StdVectorFst> new_dict(new fst::StdVectorFst);
/* This makes the FST deterministic, meaning for any string input there's
* only one possible state the FST could be in. It is assumed our
* dictionary is deterministic when using it.
* (otherwise we'd have to check for multiple transitions at each state)
*/
fst::Determinize(dictionary, new_dict);
fst::Determinize(dictionary, new_dict.get());
/* Finds the simplest equivalent fst. This is unnecessary but decreases
* memory usage of the dictionary
*/
fst::Minimize(new_dict);
this->dictionary.reset(new_dict);
fst::Minimize(new_dict.get());
// Now we convert the MutableFst to a ConstFst (Scorer::FstType) via its ctor
std::unique_ptr<FstType> converted(new FstType(*new_dict));
this->dictionary = std::move(converted);
}

View File

@ -40,6 +40,8 @@ public:
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
*/
class Scorer {
using FstType = PathTrie::FstType;
public:
Scorer(double alpha,
double beta,
@ -82,7 +84,7 @@ public:
double beta;
// pointer to the dictionary of FST
std::unique_ptr<fst::StdVectorFst> dictionary;
std::unique_ptr<FstType> dictionary;
protected:
// necessary setup: load language model, fill FST's dictionary

View File

@ -3,7 +3,6 @@
#include <string>
#include "ctcdecode/scorer.h"
#include "fst/fstlib.h"
#include "alphabet.h"
using namespace std;

View File

@ -3,7 +3,6 @@
#include <string>
#include "ctcdecode/scorer.h"
#include "fst/fstlib.h"
#include "alphabet.h"
using namespace std;