Load combined format from Scorer
This commit is contained in:
parent
214b50f490
commit
b33d90b7bd
@ -88,46 +88,56 @@ void Scorer::load_lm(const std::string& lm_path, const std::string& trie_path)
|
|||||||
const char* filename = lm_path.c_str();
|
const char* filename = lm_path.c_str();
|
||||||
VALID_CHECK_EQ(access(filename, R_OK), 0, "Invalid language model path");
|
VALID_CHECK_EQ(access(filename, R_OK), 0, "Invalid language model path");
|
||||||
|
|
||||||
bool has_trie = trie_path.size() && access(trie_path.c_str(), R_OK) == 0;
|
|
||||||
// VALID_CHECK(has_trie, "Invalid trie path");
|
|
||||||
|
|
||||||
lm::ngram::Config config;
|
lm::ngram::Config config;
|
||||||
config.load_method = util::LoadMethod::LAZY;
|
config.load_method = util::LoadMethod::LAZY;
|
||||||
language_model_.reset(lm::ngram::LoadVirtual(filename, config));
|
language_model_.reset(lm::ngram::LoadVirtual(filename, config));
|
||||||
|
|
||||||
|
uint64_t package_size;
|
||||||
|
{
|
||||||
|
util::scoped_fd fd(util::OpenReadOrThrow(filename));
|
||||||
|
package_size = util::SizeFile(fd.get());
|
||||||
|
}
|
||||||
|
uint64_t trie_offset = language_model_->GetEndOfSearchOffset();
|
||||||
|
bool has_trie = package_size > trie_offset;
|
||||||
|
|
||||||
if (has_trie) {
|
if (has_trie) {
|
||||||
// Read metadata and trie from file
|
// Read metadata and trie from file
|
||||||
std::ifstream fin(trie_path, std::ios::binary);
|
std::ifstream fin(lm_path, std::ios::binary);
|
||||||
|
fin.seekg(trie_offset);
|
||||||
int magic;
|
load_trie(fin, lm_path);
|
||||||
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
|
|
||||||
if (magic != MAGIC) {
|
|
||||||
std::cerr << "Error: Can't parse trie file, invalid header. Try updating "
|
|
||||||
"your trie file." << std::endl;
|
|
||||||
throw 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
int version;
|
|
||||||
fin.read(reinterpret_cast<char*>(&version), sizeof(version));
|
|
||||||
if (version != FILE_VERSION) {
|
|
||||||
std::cerr << "Error: Trie file version mismatch (" << version
|
|
||||||
<< " instead of expected " << FILE_VERSION
|
|
||||||
<< "). Update your trie file."
|
|
||||||
<< std::endl;
|
|
||||||
throw 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
|
||||||
|
|
||||||
fst::FstReadOptions opt;
|
|
||||||
opt.mode = fst::FstReadOptions::MAP;
|
|
||||||
opt.source = trie_path;
|
|
||||||
dictionary.reset(FstType::Read(fin, opt));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
max_order_ = language_model_->Order();
|
max_order_ = language_model_->Order();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
|
||||||
|
{
|
||||||
|
int magic;
|
||||||
|
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
|
||||||
|
if (magic != MAGIC) {
|
||||||
|
std::cerr << "Error: Can't parse trie file, invalid header. Try updating "
|
||||||
|
"your trie file." << std::endl;
|
||||||
|
throw 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int version;
|
||||||
|
fin.read(reinterpret_cast<char*>(&version), sizeof(version));
|
||||||
|
if (version != FILE_VERSION) {
|
||||||
|
std::cerr << "Error: Trie file version mismatch (" << version
|
||||||
|
<< " instead of expected " << FILE_VERSION
|
||||||
|
<< "). Update your trie file."
|
||||||
|
<< std::endl;
|
||||||
|
throw 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
||||||
|
|
||||||
|
fst::FstReadOptions opt;
|
||||||
|
opt.mode = fst::FstReadOptions::MAP;
|
||||||
|
opt.source = file_path;
|
||||||
|
dictionary.reset(FstType::Read(fin, opt));
|
||||||
|
}
|
||||||
|
|
||||||
void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)
|
void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)
|
||||||
{
|
{
|
||||||
std::ios::openmode om;
|
std::ios::openmode om;
|
||||||
|
@ -118,6 +118,8 @@ protected:
|
|||||||
// necessary setup after setting alphabet
|
// necessary setup after setting alphabet
|
||||||
void setup_char_map();
|
void setup_char_map();
|
||||||
|
|
||||||
|
void load_trie(std::ifstream& fin, const std::string& file_path);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<lm::base::Model> language_model_;
|
std::unique_ptr<lm::base::Model> language_model_;
|
||||||
bool is_utf8_mode_ = true;
|
bool is_utf8_mode_ = true;
|
||||||
|
@ -226,6 +226,10 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <class Search, class VocabularyT> uint64_t GenericModel<Search, VocabularyT>::GetEndOfSearchOffset() const {
|
||||||
|
return backing_.VocabStringReadingOffset();
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Do a paraonoid copy of history, assuming new_word has already been copied
|
// Do a paraonoid copy of history, assuming new_word has already been copied
|
||||||
// (hence the -1). out_state.length could be zero so I avoided using
|
// (hence the -1). out_state.length could be zero so I avoided using
|
||||||
|
@ -102,6 +102,8 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
|
|||||||
return Search::kDifferentRest ? InternalUnRest(pointers_begin, pointers_end, first_length) : 0.0;
|
return Search::kDifferentRest ? InternalUnRest(pointers_begin, pointers_end, first_length) : 0.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint64_t GetEndOfSearchOffset() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const;
|
FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const;
|
||||||
|
|
||||||
|
@ -137,6 +137,8 @@ class Model {
|
|||||||
|
|
||||||
const Vocabulary &BaseVocabulary() const { return *base_vocab_; }
|
const Vocabulary &BaseVocabulary() const { return *base_vocab_; }
|
||||||
|
|
||||||
|
virtual uint64_t GetEndOfSearchOffset() const = 0;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
template <class T, class U, class V> friend class ModelFacade;
|
template <class T, class U, class V> friend class ModelFacade;
|
||||||
explicit Model(size_t state_size) : state_size_(state_size) {}
|
explicit Model(size_t state_size) : state_size_(state_size) {}
|
||||||
|
Loading…
Reference in New Issue
Block a user