Load combined format from Scorer

This commit is contained in:
Reuben Morais 2020-01-16 16:16:06 +01:00
parent 214b50f490
commit b33d90b7bd
5 changed files with 49 additions and 29 deletions

View File

@ -88,17 +88,30 @@ void Scorer::load_lm(const std::string& lm_path, const std::string& trie_path)
const char* filename = lm_path.c_str(); const char* filename = lm_path.c_str();
VALID_CHECK_EQ(access(filename, R_OK), 0, "Invalid language model path"); VALID_CHECK_EQ(access(filename, R_OK), 0, "Invalid language model path");
bool has_trie = trie_path.size() && access(trie_path.c_str(), R_OK) == 0;
// VALID_CHECK(has_trie, "Invalid trie path");
lm::ngram::Config config; lm::ngram::Config config;
config.load_method = util::LoadMethod::LAZY; config.load_method = util::LoadMethod::LAZY;
language_model_.reset(lm::ngram::LoadVirtual(filename, config)); language_model_.reset(lm::ngram::LoadVirtual(filename, config));
uint64_t package_size;
{
util::scoped_fd fd(util::OpenReadOrThrow(filename));
package_size = util::SizeFile(fd.get());
}
uint64_t trie_offset = language_model_->GetEndOfSearchOffset();
bool has_trie = package_size > trie_offset;
if (has_trie) { if (has_trie) {
// Read metadata and trie from file // Read metadata and trie from file
std::ifstream fin(trie_path, std::ios::binary); std::ifstream fin(lm_path, std::ios::binary);
fin.seekg(trie_offset);
load_trie(fin, lm_path);
}
max_order_ = language_model_->Order();
}
void Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
{
int magic; int magic;
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic)); fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
if (magic != MAGIC) { if (magic != MAGIC) {
@ -121,13 +134,10 @@ void Scorer::load_lm(const std::string& lm_path, const std::string& trie_path)
fst::FstReadOptions opt; fst::FstReadOptions opt;
opt.mode = fst::FstReadOptions::MAP; opt.mode = fst::FstReadOptions::MAP;
opt.source = trie_path; opt.source = file_path;
dictionary.reset(FstType::Read(fin, opt)); dictionary.reset(FstType::Read(fin, opt));
} }
max_order_ = language_model_->Order();
}
void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite) void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)
{ {
std::ios::openmode om; std::ios::openmode om;

View File

@ -118,6 +118,8 @@ protected:
// necessary setup after setting alphabet // necessary setup after setting alphabet
void setup_char_map(); void setup_char_map();
void load_trie(std::ifstream& fin, const std::string& file_path);
private: private:
std::unique_ptr<lm::base::Model> language_model_; std::unique_ptr<lm::base::Model> language_model_;
bool is_utf8_mode_ = true; bool is_utf8_mode_ = true;

View File

@ -226,6 +226,10 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
return ret; return ret;
} }
template <class Search, class VocabularyT> uint64_t GenericModel<Search, VocabularyT>::GetEndOfSearchOffset() const {
return backing_.VocabStringReadingOffset();
}
namespace { namespace {
// Do a paraonoid copy of history, assuming new_word has already been copied // Do a paraonoid copy of history, assuming new_word has already been copied
// (hence the -1). out_state.length could be zero so I avoided using // (hence the -1). out_state.length could be zero so I avoided using

View File

@ -102,6 +102,8 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
return Search::kDifferentRest ? InternalUnRest(pointers_begin, pointers_end, first_length) : 0.0; return Search::kDifferentRest ? InternalUnRest(pointers_begin, pointers_end, first_length) : 0.0;
} }
uint64_t GetEndOfSearchOffset() const;
private: private:
FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const; FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const;

View File

@ -137,6 +137,8 @@ class Model {
const Vocabulary &BaseVocabulary() const { return *base_vocab_; } const Vocabulary &BaseVocabulary() const { return *base_vocab_; }
virtual uint64_t GetEndOfSearchOffset() const = 0;
private: private:
template <class T, class U, class V> friend class ModelFacade; template <class T, class U, class V> friend class ModelFacade;
explicit Model(size_t state_size) : state_size_(state_size) {} explicit Model(size_t state_size) : state_size_(state_size) {}