Merge pull request #2194 from mozilla/const-fst
Switch to ConstFst from VectorFst and mmap trie file when reading
This commit is contained in:
commit
0ea580449c
@ -1,3 +1,3 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:a5324f06b27c7b4ef88dd2c8bc3d05d4c718db219c2e007f59061a30c9ac7afa
|
oid sha256:1991108e83cf86830cad118ac9e061aadfc89b73a909a66fe1755f846def556e
|
||||||
size 21627983
|
size 24480560
|
||||||
|
|||||||
Binary file not shown.
@ -12,8 +12,6 @@
|
|||||||
#include "fst/fstlib.h"
|
#include "fst/fstlib.h"
|
||||||
#include "path_trie.h"
|
#include "path_trie.h"
|
||||||
|
|
||||||
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
|
||||||
|
|
||||||
DecoderState*
|
DecoderState*
|
||||||
decoder_init(const Alphabet &alphabet,
|
decoder_init(const Alphabet &alphabet,
|
||||||
int class_dim,
|
int class_dim,
|
||||||
@ -41,7 +39,7 @@ decoder_init(const Alphabet &alphabet,
|
|||||||
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
|
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
|
||||||
auto dict_ptr = ext_scorer->dictionary->Copy(true);
|
auto dict_ptr = ext_scorer->dictionary->Copy(true);
|
||||||
root->set_dictionary(dict_ptr);
|
root->set_dictionary(dict_ptr);
|
||||||
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
|
auto matcher = std::make_shared<fst::SortedMatcher<PathTrie::FstType>>(*dict_ptr, fst::MATCH_INPUT);
|
||||||
root->set_matcher(matcher);
|
root->set_matcher(matcher);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -162,13 +162,12 @@ void PathTrie::remove() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) {
|
void PathTrie::set_dictionary(PathTrie::FstType* dictionary) {
|
||||||
dictionary_ = dictionary;
|
dictionary_ = dictionary;
|
||||||
dictionary_state_ = dictionary->Start();
|
dictionary_state_ = dictionary->Start();
|
||||||
has_dictionary_ = true;
|
has_dictionary_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
void PathTrie::set_matcher(std::shared_ptr<fst::SortedMatcher<FstType>> matcher) {
|
||||||
void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) {
|
|
||||||
matcher_ = matcher;
|
matcher_ = matcher;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -14,6 +14,8 @@
|
|||||||
*/
|
*/
|
||||||
class PathTrie {
|
class PathTrie {
|
||||||
public:
|
public:
|
||||||
|
using FstType = fst::ConstFst<fst::StdArc>;
|
||||||
|
|
||||||
PathTrie();
|
PathTrie();
|
||||||
~PathTrie();
|
~PathTrie();
|
||||||
|
|
||||||
@ -33,9 +35,9 @@ public:
|
|||||||
void iterate_to_vec(std::vector<PathTrie*>& output);
|
void iterate_to_vec(std::vector<PathTrie*>& output);
|
||||||
|
|
||||||
// set dictionary for FST
|
// set dictionary for FST
|
||||||
void set_dictionary(fst::StdVectorFst* dictionary);
|
void set_dictionary(FstType* dictionary);
|
||||||
|
|
||||||
void set_matcher(std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>>);
|
void set_matcher(std::shared_ptr<fst::SortedMatcher<FstType>>);
|
||||||
|
|
||||||
bool is_empty() { return ROOT_ == character; }
|
bool is_empty() { return ROOT_ == character; }
|
||||||
|
|
||||||
@ -61,10 +63,10 @@ private:
|
|||||||
std::vector<std::pair<int, PathTrie*>> children_;
|
std::vector<std::pair<int, PathTrie*>> children_;
|
||||||
|
|
||||||
// pointer to dictionary of FST
|
// pointer to dictionary of FST
|
||||||
fst::StdVectorFst* dictionary_;
|
FstType* dictionary_;
|
||||||
fst::StdVectorFst::StateId dictionary_state_;
|
FstType::StateId dictionary_state_;
|
||||||
// true if finding arcs in FST
|
// true if finding arcs in FST
|
||||||
std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> matcher_;
|
std::shared_ptr<fst::SortedMatcher<FstType>> matcher_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // PATH_TRIE_H
|
#endif // PATH_TRIE_H
|
||||||
|
|||||||
@ -26,8 +26,8 @@
|
|||||||
|
|
||||||
using namespace lm::ngram;
|
using namespace lm::ngram;
|
||||||
|
|
||||||
static const int MAGIC = 'TRIE';
|
static const int32_t MAGIC = 'TRIE';
|
||||||
static const int FILE_VERSION = 3;
|
static const int32_t FILE_VERSION = 4;
|
||||||
|
|
||||||
Scorer::Scorer(double alpha,
|
Scorer::Scorer(double alpha,
|
||||||
double beta,
|
double beta,
|
||||||
@ -123,7 +123,9 @@ void Scorer::setup(const std::string& lm_path, const std::string& trie_path)
|
|||||||
|
|
||||||
if (!is_character_based_) {
|
if (!is_character_based_) {
|
||||||
fst::FstReadOptions opt;
|
fst::FstReadOptions opt;
|
||||||
dictionary.reset(fst::StdVectorFst::Read(fin, opt));
|
opt.mode = fst::FstReadOptions::MAP;
|
||||||
|
opt.source = trie_path;
|
||||||
|
dictionary.reset(FstType::Read(fin, opt));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -138,6 +140,8 @@ void Scorer::save_dictionary(const std::string& path)
|
|||||||
fout.write(reinterpret_cast<const char*>(&is_character_based_), sizeof(is_character_based_));
|
fout.write(reinterpret_cast<const char*>(&is_character_based_), sizeof(is_character_based_));
|
||||||
if (!is_character_based_) {
|
if (!is_character_based_) {
|
||||||
fst::FstWriteOptions opt;
|
fst::FstWriteOptions opt;
|
||||||
|
opt.align = true;
|
||||||
|
opt.source = path;
|
||||||
dictionary->Write(fout, opt);
|
dictionary->Write(fout, opt);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -248,6 +252,8 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
|
|||||||
|
|
||||||
void Scorer::fill_dictionary(const std::vector<std::string>& vocabulary, bool add_space)
|
void Scorer::fill_dictionary(const std::vector<std::string>& vocabulary, bool add_space)
|
||||||
{
|
{
|
||||||
|
// ConstFst is immutable, so we need to use a MutableFst to create the trie,
|
||||||
|
// and then we convert to a ConstFst for the decoder and for storing on disk.
|
||||||
fst::StdVectorFst dictionary;
|
fst::StdVectorFst dictionary;
|
||||||
// For each unigram convert to ints and put in trie
|
// For each unigram convert to ints and put in trie
|
||||||
for (const auto& word : vocabulary) {
|
for (const auto& word : vocabulary) {
|
||||||
@ -262,18 +268,21 @@ void Scorer::fill_dictionary(const std::vector<std::string>& vocabulary, bool ad
|
|||||||
* can greatly increase the size of the FST
|
* can greatly increase the size of the FST
|
||||||
*/
|
*/
|
||||||
fst::RmEpsilon(&dictionary);
|
fst::RmEpsilon(&dictionary);
|
||||||
fst::StdVectorFst* new_dict = new fst::StdVectorFst;
|
std::unique_ptr<fst::StdVectorFst> new_dict(new fst::StdVectorFst);
|
||||||
|
|
||||||
/* This makes the FST deterministic, meaning for any string input there's
|
/* This makes the FST deterministic, meaning for any string input there's
|
||||||
* only one possible state the FST could be in. It is assumed our
|
* only one possible state the FST could be in. It is assumed our
|
||||||
* dictionary is deterministic when using it.
|
* dictionary is deterministic when using it.
|
||||||
* (lest we'd have to check for multiple transitions at each state)
|
* (lest we'd have to check for multiple transitions at each state)
|
||||||
*/
|
*/
|
||||||
fst::Determinize(dictionary, new_dict);
|
fst::Determinize(dictionary, new_dict.get());
|
||||||
|
|
||||||
/* Finds the simplest equivalent fst. This is unnecessary but decreases
|
/* Finds the simplest equivalent fst. This is unnecessary but decreases
|
||||||
* memory usage of the dictionary
|
* memory usage of the dictionary
|
||||||
*/
|
*/
|
||||||
fst::Minimize(new_dict);
|
fst::Minimize(new_dict.get());
|
||||||
this->dictionary.reset(new_dict);
|
|
||||||
|
// Now we convert the MutableFst to a ConstFst (Scorer::FstType) via its ctor
|
||||||
|
std::unique_ptr<FstType> converted(new FstType(*new_dict));
|
||||||
|
this->dictionary = std::move(converted);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -40,6 +40,8 @@ public:
|
|||||||
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
|
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
|
||||||
*/
|
*/
|
||||||
class Scorer {
|
class Scorer {
|
||||||
|
using FstType = PathTrie::FstType;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Scorer(double alpha,
|
Scorer(double alpha,
|
||||||
double beta,
|
double beta,
|
||||||
@ -82,7 +84,7 @@ public:
|
|||||||
double beta;
|
double beta;
|
||||||
|
|
||||||
// pointer to the dictionary of FST
|
// pointer to the dictionary of FST
|
||||||
std::unique_ptr<fst::StdVectorFst> dictionary;
|
std::unique_ptr<FstType> dictionary;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// necessary setup: load language model, fill FST's dictionary
|
// necessary setup: load language model, fill FST's dictionary
|
||||||
|
|||||||
@ -3,7 +3,6 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "ctcdecode/scorer.h"
|
#include "ctcdecode/scorer.h"
|
||||||
#include "fst/fstlib.h"
|
|
||||||
#include "alphabet.h"
|
#include "alphabet.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|||||||
@ -3,7 +3,6 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "ctcdecode/scorer.h"
|
#include "ctcdecode/scorer.h"
|
||||||
#include "fst/fstlib.h"
|
|
||||||
#include "alphabet.h"
|
#include "alphabet.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user