diff --git a/data/lm/trie b/data/lm/trie index a937a226..87dd3a32 100644 --- a/data/lm/trie +++ b/data/lm/trie @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a5324f06b27c7b4ef88dd2c8bc3d05d4c718db219c2e007f59061a30c9ac7afa -size 21627983 +oid sha256:1991108e83cf86830cad118ac9e061aadfc89b73a909a66fe1755f846def556e +size 24480560 diff --git a/data/smoke_test/vocab.trie b/data/smoke_test/vocab.trie index 3a5d637f..7237e739 100644 Binary files a/data/smoke_test/vocab.trie and b/data/smoke_test/vocab.trie differ diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.cpp b/native_client/ctcdecode/ctc_beam_search_decoder.cpp index 1b7f0e98..e68d778f 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.cpp +++ b/native_client/ctcdecode/ctc_beam_search_decoder.cpp @@ -12,8 +12,6 @@ #include "fst/fstlib.h" #include "path_trie.h" -using FSTMATCH = fst::SortedMatcher; - DecoderState* decoder_init(const Alphabet &alphabet, int class_dim, @@ -41,7 +39,7 @@ decoder_init(const Alphabet &alphabet, if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { auto dict_ptr = ext_scorer->dictionary->Copy(true); root->set_dictionary(dict_ptr); - auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); + auto matcher = std::make_shared>(*dict_ptr, fst::MATCH_INPUT); root->set_matcher(matcher); } diff --git a/native_client/ctcdecode/path_trie.cpp b/native_client/ctcdecode/path_trie.cpp index 7fe3fbde..560361ec 100644 --- a/native_client/ctcdecode/path_trie.cpp +++ b/native_client/ctcdecode/path_trie.cpp @@ -162,13 +162,12 @@ void PathTrie::remove() { } } -void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) { +void PathTrie::set_dictionary(PathTrie::FstType* dictionary) { dictionary_ = dictionary; dictionary_state_ = dictionary->Start(); has_dictionary_ = true; } -using FSTMATCH = fst::SortedMatcher; -void PathTrie::set_matcher(std::shared_ptr matcher) { +void PathTrie::set_matcher(std::shared_ptr> matcher) { matcher_ = matcher; } diff --git a/native_client/ctcdecode/path_trie.h b/native_client/ctcdecode/path_trie.h index baa27dbe..10a1b687 100644 --- a/native_client/ctcdecode/path_trie.h +++ b/native_client/ctcdecode/path_trie.h @@ -14,6 +14,8 @@ */ class PathTrie { public: + using FstType = fst::ConstFst; + PathTrie(); ~PathTrie(); @@ -33,9 +35,9 @@ public: void iterate_to_vec(std::vector& output); // set dictionary for FST - void set_dictionary(fst::StdVectorFst* dictionary); + void set_dictionary(FstType* dictionary); - void set_matcher(std::shared_ptr>); + void set_matcher(std::shared_ptr>); bool is_empty() { return ROOT_ == character; } @@ -61,10 +63,10 @@ private: std::vector> children_; // pointer to dictionary of FST - fst::StdVectorFst* dictionary_; - fst::StdVectorFst::StateId dictionary_state_; + FstType* dictionary_; + FstType::StateId dictionary_state_; // true if finding ars in FST - std::shared_ptr> matcher_; + std::shared_ptr> matcher_; }; #endif // PATH_TRIE_H diff --git a/native_client/ctcdecode/scorer.cpp b/native_client/ctcdecode/scorer.cpp index 39b9caf4..92479300 100644 --- a/native_client/ctcdecode/scorer.cpp +++ b/native_client/ctcdecode/scorer.cpp @@ -26,8 +26,8 @@ using namespace lm::ngram; -static const int MAGIC = 'TRIE'; -static const int FILE_VERSION = 3; +static const int32_t MAGIC = 'TRIE'; +static const int32_t FILE_VERSION = 4; Scorer::Scorer(double alpha, double beta, @@ -123,7 +123,9 @@ void Scorer::setup(const std::string& lm_path, const std::string& trie_path) if (!is_character_based_) { fst::FstReadOptions opt; - dictionary.reset(fst::StdVectorFst::Read(fin, opt)); + opt.mode = fst::FstReadOptions::MAP; + opt.source = trie_path; + dictionary.reset(FstType::Read(fin, opt)); } } @@ -138,6 +140,8 @@ void Scorer::save_dictionary(const std::string& path) fout.write(reinterpret_cast(&is_character_based_), sizeof(is_character_based_)); if (!is_character_based_) { fst::FstWriteOptions opt; + opt.align = true; + opt.source = path; dictionary->Write(fout, opt); } } @@ -248,6 +252,8 @@ std::vector Scorer::make_ngram(PathTrie* prefix) void Scorer::fill_dictionary(const std::vector& vocabulary, bool add_space) { + // ConstFst is immutable, so we need to use a MutableFst to create the trie, + // and then we convert to a ConstFst for the decoder and for storing on disk. fst::StdVectorFst dictionary; // For each unigram convert to ints and put in trie for (const auto& word : vocabulary) { @@ -262,18 +268,21 @@ void Scorer::fill_dictionary(const std::vector& vocabulary, bool ad * can greatly increase the size of the FST */ fst::RmEpsilon(&dictionary); - fst::StdVectorFst* new_dict = new fst::StdVectorFst; + std::unique_ptr new_dict(new fst::StdVectorFst); /* This makes the FST deterministic, meaning for any string input there's * only one possible state the FST could be in. It is assumed our * dictionary is deterministic when using it. * (lest we'd have to check for multiple transitions at each state) */ - fst::Determinize(dictionary, new_dict); + fst::Determinize(dictionary, new_dict.get()); /* Finds the simplest equivalent fst. This is unnecessary but decreases * memory usage of the dictionary */ - fst::Minimize(new_dict); - this->dictionary.reset(new_dict); + fst::Minimize(new_dict.get()); + + // Now we convert the MutableFst to a ConstFst (Scorer::FstType) via its ctor + std::unique_ptr converted(new FstType(*new_dict)); + this->dictionary = std::move(converted); } diff --git a/native_client/ctcdecode/scorer.h b/native_client/ctcdecode/scorer.h index 9a81d4d6..2e881077 100644 --- a/native_client/ctcdecode/scorer.h +++ b/native_client/ctcdecode/scorer.h @@ -40,6 +40,8 @@ public: * scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" }); */ class Scorer { + using FstType = PathTrie::FstType; + public: Scorer(double alpha, double beta, @@ -82,7 +84,7 @@ public: double beta; // pointer to the dictionary of FST - std::unique_ptr dictionary; + std::unique_ptr dictionary; protected: // necessary setup: load language model, fill FST's dictionary diff --git a/native_client/generate_trie.cpp b/native_client/generate_trie.cpp index 0beb9d41..49944c86 100644 --- a/native_client/generate_trie.cpp +++ b/native_client/generate_trie.cpp @@ -3,7 +3,6 @@ #include #include "ctcdecode/scorer.h" -#include "fst/fstlib.h" #include "alphabet.h" using namespace std; diff --git a/native_client/trie_load.cc b/native_client/trie_load.cc index ec0073c9..a719a431 100644 --- a/native_client/trie_load.cc +++ b/native_client/trie_load.cc @@ -3,7 +3,6 @@ #include #include "ctcdecode/scorer.h" -#include "fst/fstlib.h" #include "alphabet.h" using namespace std;