Load combined format from Scorer

This commit is contained in:
Reuben Morais 2020-01-16 16:16:06 +01:00
parent 214b50f490
commit b33d90b7bd
5 changed files with 49 additions and 29 deletions

View File

@ -88,46 +88,56 @@ void Scorer::load_lm(const std::string& lm_path, const std::string& trie_path)
const char* filename = lm_path.c_str();
VALID_CHECK_EQ(access(filename, R_OK), 0, "Invalid language model path");
bool has_trie = trie_path.size() && access(trie_path.c_str(), R_OK) == 0;
// VALID_CHECK(has_trie, "Invalid trie path");
lm::ngram::Config config;
config.load_method = util::LoadMethod::LAZY;
language_model_.reset(lm::ngram::LoadVirtual(filename, config));
uint64_t package_size;
{
util::scoped_fd fd(util::OpenReadOrThrow(filename));
package_size = util::SizeFile(fd.get());
}
uint64_t trie_offset = language_model_->GetEndOfSearchOffset();
bool has_trie = package_size > trie_offset;
if (has_trie) {
// Read metadata and trie from file
std::ifstream fin(trie_path, std::ios::binary);
int magic;
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
if (magic != MAGIC) {
std::cerr << "Error: Can't parse trie file, invalid header. Try updating "
"your trie file." << std::endl;
throw 1;
}
int version;
fin.read(reinterpret_cast<char*>(&version), sizeof(version));
if (version != FILE_VERSION) {
std::cerr << "Error: Trie file version mismatch (" << version
<< " instead of expected " << FILE_VERSION
<< "). Update your trie file."
<< std::endl;
throw 1;
}
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
fst::FstReadOptions opt;
opt.mode = fst::FstReadOptions::MAP;
opt.source = trie_path;
dictionary.reset(FstType::Read(fin, opt));
std::ifstream fin(lm_path, std::ios::binary);
fin.seekg(trie_offset);
load_trie(fin, lm_path);
}
max_order_ = language_model_->Order();
}
void Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
{
int magic;
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
if (magic != MAGIC) {
std::cerr << "Error: Can't parse trie file, invalid header. Try updating "
"your trie file." << std::endl;
throw 1;
}
int version;
fin.read(reinterpret_cast<char*>(&version), sizeof(version));
if (version != FILE_VERSION) {
std::cerr << "Error: Trie file version mismatch (" << version
<< " instead of expected " << FILE_VERSION
<< "). Update your trie file."
<< std::endl;
throw 1;
}
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
fst::FstReadOptions opt;
opt.mode = fst::FstReadOptions::MAP;
opt.source = file_path;
dictionary.reset(FstType::Read(fin, opt));
}
void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)
{
std::ios::openmode om;

View File

@ -118,6 +118,8 @@ protected:
// necessary setup after setting alphabet
void setup_char_map();
void load_trie(std::ifstream& fin, const std::string& file_path);
private:
std::unique_ptr<lm::base::Model> language_model_;
bool is_utf8_mode_ = true;

View File

@ -226,6 +226,10 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
return ret;
}
template <class Search, class VocabularyT> uint64_t GenericModel<Search, VocabularyT>::GetEndOfSearchOffset() const {
return backing_.VocabStringReadingOffset();
}
namespace {
// Do a paraonoid copy of history, assuming new_word has already been copied
// (hence the -1). out_state.length could be zero so I avoided using

View File

@ -102,6 +102,8 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
return Search::kDifferentRest ? InternalUnRest(pointers_begin, pointers_end, first_length) : 0.0;
}
uint64_t GetEndOfSearchOffset() const;
private:
FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const;

View File

@ -137,6 +137,8 @@ class Model {
const Vocabulary &BaseVocabulary() const { return *base_vocab_; }
virtual uint64_t GetEndOfSearchOffset() const = 0;
private:
template <class T, class U, class V> friend class ModelFacade;
explicit Model(size_t state_size) : state_size_(state_size) {}