Embed the trie in the language model file.

Scorer::load_lm() no longer opens the file named by trie_path. It compares the
size of the LM file with the offset returned by the new virtual
Model::GetEndOfSearchOffset(); when the file is larger, the bytes starting at
that offset are parsed as the trie by the new Scorer::load_trie().

Notes on this copy of the patch: line breaks, indentation and template
argument lists were reconstructed from a flattened copy, so context whitespace
may require `git apply --ignore-whitespace`. The model.cc hunk carries trailing
context only. The `index` hashes are those of the original commit.
load_trie() additionally checks the stream state while validating the header.

diff --git a/native_client/ctcdecode/scorer.cpp b/native_client/ctcdecode/scorer.cpp
index dfe2824a..5bd4da8e 100644
--- a/native_client/ctcdecode/scorer.cpp
+++ b/native_client/ctcdecode/scorer.cpp
@@ -88,46 +88,64 @@ void Scorer::load_lm(const std::string& lm_path, const std::string& trie_path)
   const char* filename = lm_path.c_str();
   VALID_CHECK_EQ(access(filename, R_OK), 0, "Invalid language model path");
 
-  bool has_trie = trie_path.size() && access(trie_path.c_str(), R_OK) == 0;
-  // VALID_CHECK(has_trie, "Invalid trie path");
-
   lm::ngram::Config config;
   config.load_method = util::LoadMethod::LAZY;
   language_model_.reset(lm::ngram::LoadVirtual(filename, config));
 
+  // A trie, if present, is stored in the same file starting at the offset
+  // reported by GetEndOfSearchOffset(), so a larger file implies a trie.
+  uint64_t package_size;
+  {
+    util::scoped_fd fd(util::OpenReadOrThrow(filename));
+    package_size = util::SizeFile(fd.get());
+  }
+  uint64_t trie_offset = language_model_->GetEndOfSearchOffset();
+  bool has_trie = package_size > trie_offset;
+
   if (has_trie) {
     // Read metadata and trie from file
-    std::ifstream fin(trie_path, std::ios::binary);
-
-    int magic;
-    fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
-    if (magic != MAGIC) {
-      std::cerr << "Error: Can't parse trie file, invalid header. Try updating "
-                   "your trie file." << std::endl;
-      throw 1;
-    }
-
-    int version;
-    fin.read(reinterpret_cast<char*>(&version), sizeof(version));
-    if (version != FILE_VERSION) {
-      std::cerr << "Error: Trie file version mismatch (" << version
-                << " instead of expected " << FILE_VERSION
-                << "). Update your trie file."
-                << std::endl;
-      throw 1;
-    }
-
-    fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
-
-    fst::FstReadOptions opt;
-    opt.mode = fst::FstReadOptions::MAP;
-    opt.source = trie_path;
-    dictionary.reset(FstType::Read(fin, opt));
+    std::ifstream fin(lm_path, std::ios::binary);
+    fin.seekg(trie_offset);
+    load_trie(fin, lm_path);
   }
 
   max_order_ = language_model_->Order();
 }
 
+// Reads the trie section from `fin`, which must already be positioned at its
+// start: magic number, file version, UTF-8 mode flag, then the serialized FST.
+// `file_path` is stored in the FST read options as the source. Throws 1 when
+// the header cannot be read or does not match MAGIC / FILE_VERSION.
+void Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
+{
+  // Zero-initialise and test the stream so a failed open, seek or short read
+  // is reported as a bad header instead of comparing an indeterminate value.
+  int magic = 0;
+  fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
+  if (!fin || magic != MAGIC) {
+    std::cerr << "Error: Can't parse trie file, invalid header. Try updating "
+                 "your trie file." << std::endl;
+    throw 1;
+  }
+
+  int version = 0;
+  fin.read(reinterpret_cast<char*>(&version), sizeof(version));
+  if (!fin || version != FILE_VERSION) {
+    std::cerr << "Error: Trie file version mismatch (" << version
+              << " instead of expected " << FILE_VERSION
+              << "). Update your trie file."
+              << std::endl;
+    throw 1;
+  }
+
+  fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
+
+  fst::FstReadOptions opt;
+  opt.mode = fst::FstReadOptions::MAP;
+  opt.source = file_path;
+  dictionary.reset(FstType::Read(fin, opt));
+}
+
 void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)
 {
   std::ios::openmode om;
diff --git a/native_client/ctcdecode/scorer.h b/native_client/ctcdecode/scorer.h
index 17bd1028..e4b86c9a 100644
--- a/native_client/ctcdecode/scorer.h
+++ b/native_client/ctcdecode/scorer.h
@@ -118,6 +118,9 @@ protected:
   // necessary setup after setting alphabet
   void setup_char_map();
 
+  // Reads trie metadata and the FST from a stream positioned at the trie start.
+  void load_trie(std::ifstream& fin, const std::string& file_path);
+
 private:
   std::unique_ptr<lm::base::Model> language_model_;
   bool is_utf8_mode_ = true;
diff --git a/native_client/kenlm/lm/model.cc b/native_client/kenlm/lm/model.cc
index a5a16bf8..fc4e374c 100644
--- a/native_client/kenlm/lm/model.cc
+++ b/native_client/kenlm/lm/model.cc
@@ -229,3 +229,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT
+template <class Search, class VocabularyT> uint64_t GenericModel<Search, VocabularyT>::GetEndOfSearchOffset() const {
+  return backing_.VocabStringReadingOffset();
+}
+
 namespace {
 // Do a paraonoid copy of history, assuming new_word has already been copied
 // (hence the -1). out_state.length could be zero so I avoided using
diff --git a/native_client/kenlm/lm/model.hh b/native_client/kenlm/lm/model.hh
index b2bbe399..9b7206e8 100644
--- a/native_client/kenlm/lm/model.hh
+++ b/native_client/kenlm/lm/model.hh
@@ -102,6 +102,8 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
       return Search::kDifferentRest ? InternalUnRest(pointers_begin, pointers_end, first_length) : 0.0;
     }
 
+    uint64_t GetEndOfSearchOffset() const;
+
   private:
     FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const;
 
diff --git a/native_client/kenlm/lm/virtual_interface.hh b/native_client/kenlm/lm/virtual_interface.hh
index ea491fbf..91abe90e 100644
--- a/native_client/kenlm/lm/virtual_interface.hh
+++ b/native_client/kenlm/lm/virtual_interface.hh
@@ -137,6 +137,8 @@ class Model {
 
     const Vocabulary &BaseVocabulary() const { return *base_vocab_; }
 
+    virtual uint64_t GetEndOfSearchOffset() const = 0;
+
   private:
     template <class T, class U, class V> friend class ModelFacade;
     explicit Model(size_t state_size) : state_size_(state_size) {}