Improve error handling around Scorer loading

This commit is contained in:
Reuben Morais 2020-01-21 22:27:06 +01:00
parent 3b54f54524
commit efbed73d5c
2 changed files with 19 additions and 22 deletions

View File

@ -31,10 +31,8 @@ int
Scorer::init(const std::string& lm_path, Scorer::init(const std::string& lm_path,
const Alphabet& alphabet) const Alphabet& alphabet)
{ {
alphabet_ = alphabet; set_alphabet(alphabet);
setup_char_map(); return load_lm(lm_path);
load_lm(lm_path);
return 0;
} }
int int
@ -46,8 +44,7 @@ Scorer::init(const std::string& lm_path,
return err; return err;
} }
setup_char_map(); setup_char_map();
load_lm(lm_path); return load_lm(lm_path);
return 0;
} }
void void
@ -72,15 +69,14 @@ void Scorer::setup_char_map()
} }
} }
void Scorer::load_lm(const std::string& lm_path) int Scorer::load_lm(const std::string& lm_path)
{ {
// load language model // load language model
const char* filename = lm_path.c_str(); const char* filename = lm_path.c_str();
VALID_CHECK_EQ(access(filename, R_OK), 0, "Invalid language model path");
lm::ngram::Config config; lm::ngram::Config config;
config.load_method = util::LoadMethod::LAZY; config.load_method = util::LoadMethod::LAZY;
language_model_.reset(lm::ngram::LoadVirtual(filename, config)); language_model_.reset(lm::ngram::LoadVirtual(filename, config));
max_order_ = language_model_->Order();
uint64_t package_size; uint64_t package_size;
{ {
@ -88,26 +84,25 @@ void Scorer::load_lm(const std::string& lm_path)
package_size = util::SizeFile(fd.get()); package_size = util::SizeFile(fd.get());
} }
uint64_t trie_offset = language_model_->GetEndOfSearchOffset(); uint64_t trie_offset = language_model_->GetEndOfSearchOffset();
bool has_trie = package_size > trie_offset; if (package_size <= trie_offset) {
// File ends without a trie structure
if (has_trie) { return 1;
// Read metadata and trie from file
std::ifstream fin(lm_path, std::ios::binary);
fin.seekg(trie_offset);
load_trie(fin, lm_path);
} }
max_order_ = language_model_->Order(); // Read metadata and trie from file
std::ifstream fin(lm_path, std::ios::binary);
fin.seekg(trie_offset);
return load_trie(fin, lm_path);
} }
void Scorer::load_trie(std::ifstream& fin, const std::string& file_path) int Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
{ {
int magic; int magic;
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic)); fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
if (magic != MAGIC) { if (magic != MAGIC) {
std::cerr << "Error: Can't parse trie file, invalid header. Try updating " std::cerr << "Error: Can't parse trie file, invalid header. Try updating "
"your trie file." << std::endl; "your trie file." << std::endl;
throw 1; return 1;
} }
int version; int version;
@ -122,7 +117,7 @@ void Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
std::cerr << "Downgrade your trie file or update your version of DeepSpeech."; std::cerr << "Downgrade your trie file or update your version of DeepSpeech.";
} }
std::cerr << std::endl; std::cerr << std::endl;
throw 1; return 1;
} }
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_)); fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
@ -137,6 +132,7 @@ void Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
opt.mode = fst::FstReadOptions::MAP; opt.mode = fst::FstReadOptions::MAP;
opt.source = file_path; opt.source = file_path;
dictionary.reset(FstType::Read(fin, opt)); dictionary.reset(FstType::Read(fin, opt));
return 0;
} }
void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite) void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)

View File

@ -12,6 +12,7 @@
#include "path_trie.h" #include "path_trie.h"
#include "alphabet.h" #include "alphabet.h"
#include "deepspeech.h"
const double OOV_SCORE = -1000.0; const double OOV_SCORE = -1000.0;
const std::string START_TOKEN = "<s>"; const std::string START_TOKEN = "<s>";
@ -85,7 +86,7 @@ public:
void fill_dictionary(const std::vector<std::string> &vocabulary); void fill_dictionary(const std::vector<std::string> &vocabulary);
// load language model from given path // load language model from given path
void load_lm(const std::string &lm_path); int load_lm(const std::string &lm_path);
// language model weight // language model weight
double alpha = 0.; double alpha = 0.;
@ -99,7 +100,7 @@ protected:
// necessary setup after setting alphabet // necessary setup after setting alphabet
void setup_char_map(); void setup_char_map();
void load_trie(std::ifstream& fin, const std::string& file_path); int load_trie(std::ifstream& fin, const std::string& file_path);
private: private:
std::unique_ptr<lm::base::Model> language_model_; std::unique_ptr<lm::base::Model> language_model_;