Load combined format from Scorer
This commit is contained in:
parent
214b50f490
commit
b33d90b7bd
@ -88,46 +88,56 @@ void Scorer::load_lm(const std::string& lm_path, const std::string& trie_path)
|
||||
const char* filename = lm_path.c_str();
|
||||
VALID_CHECK_EQ(access(filename, R_OK), 0, "Invalid language model path");
|
||||
|
||||
bool has_trie = trie_path.size() && access(trie_path.c_str(), R_OK) == 0;
|
||||
// VALID_CHECK(has_trie, "Invalid trie path");
|
||||
|
||||
lm::ngram::Config config;
|
||||
config.load_method = util::LoadMethod::LAZY;
|
||||
language_model_.reset(lm::ngram::LoadVirtual(filename, config));
|
||||
|
||||
uint64_t package_size;
|
||||
{
|
||||
util::scoped_fd fd(util::OpenReadOrThrow(filename));
|
||||
package_size = util::SizeFile(fd.get());
|
||||
}
|
||||
uint64_t trie_offset = language_model_->GetEndOfSearchOffset();
|
||||
bool has_trie = package_size > trie_offset;
|
||||
|
||||
if (has_trie) {
|
||||
// Read metadata and trie from file
|
||||
std::ifstream fin(trie_path, std::ios::binary);
|
||||
|
||||
int magic;
|
||||
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
|
||||
if (magic != MAGIC) {
|
||||
std::cerr << "Error: Can't parse trie file, invalid header. Try updating "
|
||||
"your trie file." << std::endl;
|
||||
throw 1;
|
||||
}
|
||||
|
||||
int version;
|
||||
fin.read(reinterpret_cast<char*>(&version), sizeof(version));
|
||||
if (version != FILE_VERSION) {
|
||||
std::cerr << "Error: Trie file version mismatch (" << version
|
||||
<< " instead of expected " << FILE_VERSION
|
||||
<< "). Update your trie file."
|
||||
<< std::endl;
|
||||
throw 1;
|
||||
}
|
||||
|
||||
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
||||
|
||||
fst::FstReadOptions opt;
|
||||
opt.mode = fst::FstReadOptions::MAP;
|
||||
opt.source = trie_path;
|
||||
dictionary.reset(FstType::Read(fin, opt));
|
||||
std::ifstream fin(lm_path, std::ios::binary);
|
||||
fin.seekg(trie_offset);
|
||||
load_trie(fin, lm_path);
|
||||
}
|
||||
|
||||
max_order_ = language_model_->Order();
|
||||
}
|
||||
|
||||
void Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
|
||||
{
|
||||
int magic;
|
||||
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
|
||||
if (magic != MAGIC) {
|
||||
std::cerr << "Error: Can't parse trie file, invalid header. Try updating "
|
||||
"your trie file." << std::endl;
|
||||
throw 1;
|
||||
}
|
||||
|
||||
int version;
|
||||
fin.read(reinterpret_cast<char*>(&version), sizeof(version));
|
||||
if (version != FILE_VERSION) {
|
||||
std::cerr << "Error: Trie file version mismatch (" << version
|
||||
<< " instead of expected " << FILE_VERSION
|
||||
<< "). Update your trie file."
|
||||
<< std::endl;
|
||||
throw 1;
|
||||
}
|
||||
|
||||
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
||||
|
||||
fst::FstReadOptions opt;
|
||||
opt.mode = fst::FstReadOptions::MAP;
|
||||
opt.source = file_path;
|
||||
dictionary.reset(FstType::Read(fin, opt));
|
||||
}
|
||||
|
||||
void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)
|
||||
{
|
||||
std::ios::openmode om;
|
||||
|
@ -118,6 +118,8 @@ protected:
|
||||
// necessary setup after setting alphabet
|
||||
void setup_char_map();
|
||||
|
||||
void load_trie(std::ifstream& fin, const std::string& file_path);
|
||||
|
||||
private:
|
||||
std::unique_ptr<lm::base::Model> language_model_;
|
||||
bool is_utf8_mode_ = true;
|
||||
|
@ -226,6 +226,10 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <class Search, class VocabularyT> uint64_t GenericModel<Search, VocabularyT>::GetEndOfSearchOffset() const {
|
||||
return backing_.VocabStringReadingOffset();
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Do a paraonoid copy of history, assuming new_word has already been copied
|
||||
// (hence the -1). out_state.length could be zero so I avoided using
|
||||
|
@ -102,6 +102,8 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
|
||||
return Search::kDifferentRest ? InternalUnRest(pointers_begin, pointers_end, first_length) : 0.0;
|
||||
}
|
||||
|
||||
uint64_t GetEndOfSearchOffset() const;
|
||||
|
||||
private:
|
||||
FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const;
|
||||
|
||||
|
@ -137,6 +137,8 @@ class Model {
|
||||
|
||||
const Vocabulary &BaseVocabulary() const { return *base_vocab_; }
|
||||
|
||||
virtual uint64_t GetEndOfSearchOffset() const = 0;
|
||||
|
||||
private:
|
||||
template <class T, class U, class V> friend class ModelFacade;
|
||||
explicit Model(size_t state_size) : state_size_(state_size) {}
|
||||
|
Loading…
Reference in New Issue
Block a user