Merge pull request #2194 from mozilla/const-fst
Switch to ConstFst from VectorFst and mmap trie file when reading
This commit is contained in:
commit
0ea580449c
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a5324f06b27c7b4ef88dd2c8bc3d05d4c718db219c2e007f59061a30c9ac7afa
|
||||
size 21627983
|
||||
oid sha256:1991108e83cf86830cad118ac9e061aadfc89b73a909a66fe1755f846def556e
|
||||
size 24480560
|
||||
|
||||
Binary file not shown.
@ -12,8 +12,6 @@
|
||||
#include "fst/fstlib.h"
|
||||
#include "path_trie.h"
|
||||
|
||||
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
||||
|
||||
DecoderState*
|
||||
decoder_init(const Alphabet &alphabet,
|
||||
int class_dim,
|
||||
@ -41,7 +39,7 @@ decoder_init(const Alphabet &alphabet,
|
||||
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
|
||||
auto dict_ptr = ext_scorer->dictionary->Copy(true);
|
||||
root->set_dictionary(dict_ptr);
|
||||
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
|
||||
auto matcher = std::make_shared<fst::SortedMatcher<PathTrie::FstType>>(*dict_ptr, fst::MATCH_INPUT);
|
||||
root->set_matcher(matcher);
|
||||
}
|
||||
|
||||
|
||||
@ -162,13 +162,12 @@ void PathTrie::remove() {
|
||||
}
|
||||
}
|
||||
|
||||
void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) {
|
||||
void PathTrie::set_dictionary(PathTrie::FstType* dictionary) {
|
||||
dictionary_ = dictionary;
|
||||
dictionary_state_ = dictionary->Start();
|
||||
has_dictionary_ = true;
|
||||
}
|
||||
|
||||
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
||||
void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) {
|
||||
void PathTrie::set_matcher(std::shared_ptr<fst::SortedMatcher<FstType>> matcher) {
|
||||
matcher_ = matcher;
|
||||
}
|
||||
|
||||
@ -14,6 +14,8 @@
|
||||
*/
|
||||
class PathTrie {
|
||||
public:
|
||||
using FstType = fst::ConstFst<fst::StdArc>;
|
||||
|
||||
PathTrie();
|
||||
~PathTrie();
|
||||
|
||||
@ -33,9 +35,9 @@ public:
|
||||
void iterate_to_vec(std::vector<PathTrie*>& output);
|
||||
|
||||
// set dictionary for FST
|
||||
void set_dictionary(fst::StdVectorFst* dictionary);
|
||||
void set_dictionary(FstType* dictionary);
|
||||
|
||||
void set_matcher(std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>>);
|
||||
void set_matcher(std::shared_ptr<fst::SortedMatcher<FstType>>);
|
||||
|
||||
bool is_empty() { return ROOT_ == character; }
|
||||
|
||||
@ -61,10 +63,10 @@ private:
|
||||
std::vector<std::pair<int, PathTrie*>> children_;
|
||||
|
||||
// pointer to the FST dictionary
|
||||
fst::StdVectorFst* dictionary_;
|
||||
fst::StdVectorFst::StateId dictionary_state_;
|
||||
FstType* dictionary_;
|
||||
FstType::StateId dictionary_state_;
|
||||
// true if finding arcs in FST
|
||||
std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> matcher_;
|
||||
std::shared_ptr<fst::SortedMatcher<FstType>> matcher_;
|
||||
};
|
||||
|
||||
#endif // PATH_TRIE_H
|
||||
|
||||
@ -26,8 +26,8 @@
|
||||
|
||||
using namespace lm::ngram;
|
||||
|
||||
static const int MAGIC = 'TRIE';
|
||||
static const int FILE_VERSION = 3;
|
||||
static const int32_t MAGIC = 'TRIE';
|
||||
static const int32_t FILE_VERSION = 4;
|
||||
|
||||
Scorer::Scorer(double alpha,
|
||||
double beta,
|
||||
@ -123,7 +123,9 @@ void Scorer::setup(const std::string& lm_path, const std::string& trie_path)
|
||||
|
||||
if (!is_character_based_) {
|
||||
fst::FstReadOptions opt;
|
||||
dictionary.reset(fst::StdVectorFst::Read(fin, opt));
|
||||
opt.mode = fst::FstReadOptions::MAP;
|
||||
opt.source = trie_path;
|
||||
dictionary.reset(FstType::Read(fin, opt));
|
||||
}
|
||||
}
|
||||
|
||||
@ -138,6 +140,8 @@ void Scorer::save_dictionary(const std::string& path)
|
||||
fout.write(reinterpret_cast<const char*>(&is_character_based_), sizeof(is_character_based_));
|
||||
if (!is_character_based_) {
|
||||
fst::FstWriteOptions opt;
|
||||
opt.align = true;
|
||||
opt.source = path;
|
||||
dictionary->Write(fout, opt);
|
||||
}
|
||||
}
|
||||
@ -248,6 +252,8 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
|
||||
|
||||
void Scorer::fill_dictionary(const std::vector<std::string>& vocabulary, bool add_space)
|
||||
{
|
||||
// ConstFst is immutable, so we need to use a MutableFst to create the trie,
|
||||
// and then we convert to a ConstFst for the decoder and for storing on disk.
|
||||
fst::StdVectorFst dictionary;
|
||||
// For each unigram convert to ints and put in trie
|
||||
for (const auto& word : vocabulary) {
|
||||
@ -262,18 +268,21 @@ void Scorer::fill_dictionary(const std::vector<std::string>& vocabulary, bool ad
|
||||
* can greatly increase the size of the FST
|
||||
*/
|
||||
fst::RmEpsilon(&dictionary);
|
||||
fst::StdVectorFst* new_dict = new fst::StdVectorFst;
|
||||
std::unique_ptr<fst::StdVectorFst> new_dict(new fst::StdVectorFst);
|
||||
|
||||
/* This makes the FST deterministic, meaning for any string input there's
|
||||
* only one possible state the FST could be in. It is assumed our
|
||||
* dictionary is deterministic when using it.
|
||||
* (otherwise we'd have to check for multiple transitions at each state)
|
||||
*/
|
||||
fst::Determinize(dictionary, new_dict);
|
||||
fst::Determinize(dictionary, new_dict.get());
|
||||
|
||||
/* Finds the simplest equivalent fst. This is unnecessary but decreases
|
||||
* memory usage of the dictionary
|
||||
*/
|
||||
fst::Minimize(new_dict);
|
||||
this->dictionary.reset(new_dict);
|
||||
fst::Minimize(new_dict.get());
|
||||
|
||||
// Now we convert the MutableFst to a ConstFst (Scorer::FstType) via its ctor
|
||||
std::unique_ptr<FstType> converted(new FstType(*new_dict));
|
||||
this->dictionary = std::move(converted);
|
||||
}
|
||||
|
||||
@ -40,6 +40,8 @@ public:
|
||||
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
|
||||
*/
|
||||
class Scorer {
|
||||
using FstType = PathTrie::FstType;
|
||||
|
||||
public:
|
||||
Scorer(double alpha,
|
||||
double beta,
|
||||
@ -82,7 +84,7 @@ public:
|
||||
double beta;
|
||||
|
||||
// pointer to the dictionary of FST
|
||||
std::unique_ptr<fst::StdVectorFst> dictionary;
|
||||
std::unique_ptr<FstType> dictionary;
|
||||
|
||||
protected:
|
||||
// necessary setup: load language model, fill FST's dictionary
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
#include <string>
|
||||
|
||||
#include "ctcdecode/scorer.h"
|
||||
#include "fst/fstlib.h"
|
||||
#include "alphabet.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
#include <string>
|
||||
|
||||
#include "ctcdecode/scorer.h"
|
||||
#include "fst/fstlib.h"
|
||||
#include "alphabet.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user