Improve error handling around Scorer loading
This commit is contained in:
parent
3b54f54524
commit
efbed73d5c
@ -31,10 +31,8 @@ int
|
|||||||
Scorer::init(const std::string& lm_path,
|
Scorer::init(const std::string& lm_path,
|
||||||
const Alphabet& alphabet)
|
const Alphabet& alphabet)
|
||||||
{
|
{
|
||||||
alphabet_ = alphabet;
|
set_alphabet(alphabet);
|
||||||
setup_char_map();
|
return load_lm(lm_path);
|
||||||
load_lm(lm_path);
|
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int
|
int
|
||||||
@ -46,8 +44,7 @@ Scorer::init(const std::string& lm_path,
|
|||||||
return err;
|
return err;
|
||||||
}
|
}
|
||||||
setup_char_map();
|
setup_char_map();
|
||||||
load_lm(lm_path);
|
return load_lm(lm_path);
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void
|
void
|
||||||
@ -72,15 +69,14 @@ void Scorer::setup_char_map()
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Scorer::load_lm(const std::string& lm_path)
|
int Scorer::load_lm(const std::string& lm_path)
|
||||||
{
|
{
|
||||||
// load language model
|
// load language model
|
||||||
const char* filename = lm_path.c_str();
|
const char* filename = lm_path.c_str();
|
||||||
VALID_CHECK_EQ(access(filename, R_OK), 0, "Invalid language model path");
|
|
||||||
|
|
||||||
lm::ngram::Config config;
|
lm::ngram::Config config;
|
||||||
config.load_method = util::LoadMethod::LAZY;
|
config.load_method = util::LoadMethod::LAZY;
|
||||||
language_model_.reset(lm::ngram::LoadVirtual(filename, config));
|
language_model_.reset(lm::ngram::LoadVirtual(filename, config));
|
||||||
|
max_order_ = language_model_->Order();
|
||||||
|
|
||||||
uint64_t package_size;
|
uint64_t package_size;
|
||||||
{
|
{
|
||||||
@ -88,26 +84,25 @@ void Scorer::load_lm(const std::string& lm_path)
|
|||||||
package_size = util::SizeFile(fd.get());
|
package_size = util::SizeFile(fd.get());
|
||||||
}
|
}
|
||||||
uint64_t trie_offset = language_model_->GetEndOfSearchOffset();
|
uint64_t trie_offset = language_model_->GetEndOfSearchOffset();
|
||||||
bool has_trie = package_size > trie_offset;
|
if (package_size <= trie_offset) {
|
||||||
|
// File ends without a trie structure
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
if (has_trie) {
|
|
||||||
// Read metadata and trie from file
|
// Read metadata and trie from file
|
||||||
std::ifstream fin(lm_path, std::ios::binary);
|
std::ifstream fin(lm_path, std::ios::binary);
|
||||||
fin.seekg(trie_offset);
|
fin.seekg(trie_offset);
|
||||||
load_trie(fin, lm_path);
|
return load_trie(fin, lm_path);
|
||||||
}
|
|
||||||
|
|
||||||
max_order_ = language_model_->Order();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
|
int Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
|
||||||
{
|
{
|
||||||
int magic;
|
int magic;
|
||||||
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
|
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
|
||||||
if (magic != MAGIC) {
|
if (magic != MAGIC) {
|
||||||
std::cerr << "Error: Can't parse trie file, invalid header. Try updating "
|
std::cerr << "Error: Can't parse trie file, invalid header. Try updating "
|
||||||
"your trie file." << std::endl;
|
"your trie file." << std::endl;
|
||||||
throw 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
int version;
|
int version;
|
||||||
@ -122,7 +117,7 @@ void Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
|
|||||||
std::cerr << "Downgrade your trie file or update your version of DeepSpeech.";
|
std::cerr << "Downgrade your trie file or update your version of DeepSpeech.";
|
||||||
}
|
}
|
||||||
std::cerr << std::endl;
|
std::cerr << std::endl;
|
||||||
throw 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
||||||
@ -137,6 +132,7 @@ void Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
|
|||||||
opt.mode = fst::FstReadOptions::MAP;
|
opt.mode = fst::FstReadOptions::MAP;
|
||||||
opt.source = file_path;
|
opt.source = file_path;
|
||||||
dictionary.reset(FstType::Read(fin, opt));
|
dictionary.reset(FstType::Read(fin, opt));
|
||||||
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)
|
void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
|
|
||||||
#include "path_trie.h"
|
#include "path_trie.h"
|
||||||
#include "alphabet.h"
|
#include "alphabet.h"
|
||||||
|
#include "deepspeech.h"
|
||||||
|
|
||||||
const double OOV_SCORE = -1000.0;
|
const double OOV_SCORE = -1000.0;
|
||||||
const std::string START_TOKEN = "<s>";
|
const std::string START_TOKEN = "<s>";
|
||||||
@ -85,7 +86,7 @@ public:
|
|||||||
void fill_dictionary(const std::vector<std::string> &vocabulary);
|
void fill_dictionary(const std::vector<std::string> &vocabulary);
|
||||||
|
|
||||||
// load language model from given path
|
// load language model from given path
|
||||||
void load_lm(const std::string &lm_path);
|
int load_lm(const std::string &lm_path);
|
||||||
|
|
||||||
// language model weight
|
// language model weight
|
||||||
double alpha = 0.;
|
double alpha = 0.;
|
||||||
@ -99,7 +100,7 @@ protected:
|
|||||||
// necessary setup after setting alphabet
|
// necessary setup after setting alphabet
|
||||||
void setup_char_map();
|
void setup_char_map();
|
||||||
|
|
||||||
void load_trie(std::ifstream& fin, const std::string& file_path);
|
int load_trie(std::ifstream& fin, const std::string& file_path);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<lm::base::Model> language_model_;
|
std::unique_ptr<lm::base::Model> language_model_;
|
||||||
|
Loading…
Reference in New Issue
Block a user