Fix #3053: Check output stream when producing scorer

This commit is contained in:
Alexandre Lissy 2020-06-16 12:40:37 +02:00
parent c839ab5355
commit 6c2cbbd725
4 changed files with 34 additions and 6 deletions

View File

@ -61,8 +61,12 @@ def create_bundle(
sys.exit(1)
scorer.fill_dictionary(list(words))
shutil.copy(lm_path, package_path)
scorer.save_dictionary(package_path, True) # append, not overwrite
print("Package created in {}".format(package_path))
# append, not overwrite
if scorer.save_dictionary(package_path, True):
print("Package created in {}".format(package_path))
else:
print("Error when creating {}".format(package_path))
sys.exit(1)
class Tristate(object):

View File

@ -47,7 +47,7 @@ class Scorer(swigwrapper.Scorer):
return super(Scorer, self).load_lm(lm_path.encode('utf-8'))
def save_dictionary(self, save_path, *args, **kwargs):
super(Scorer, self).save_dictionary(save_path.encode('utf-8'), *args, **kwargs)
return super(Scorer, self).save_dictionary(save_path.encode('utf-8'), *args, **kwargs)
def ctc_beam_search_decoder(probs_seq,

View File

@ -146,7 +146,7 @@ int Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
return DS_ERR_OK;
}
void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)
bool Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)
{
std::ios::openmode om;
if (append_instead_of_overwrite) {
@ -155,15 +155,39 @@ void Scorer::save_dictionary(const std::string& path, bool append_instead_of_ove
om = std::ios::out|std::ios::binary;
}
std::fstream fout(path, om);
if (!fout ||fout.bad()) {
std::cerr << "Error opening '" << path << "'" << std::endl;
return false;
}
fout.write(reinterpret_cast<const char*>(&MAGIC), sizeof(MAGIC));
if (fout.bad()) {
std::cerr << "Error writing MAGIC '" << path << "'" << std::endl;
return false;
}
fout.write(reinterpret_cast<const char*>(&FILE_VERSION), sizeof(FILE_VERSION));
if (fout.bad()) {
std::cerr << "Error writing FILE_VERSION '" << path << "'" << std::endl;
return false;
}
fout.write(reinterpret_cast<const char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
if (fout.bad()) {
std::cerr << "Error writing is_utf8_mode '" << path << "'" << std::endl;
return false;
}
fout.write(reinterpret_cast<const char*>(&alpha), sizeof(alpha));
if (fout.bad()) {
std::cerr << "Error writing alpha '" << path << "'" << std::endl;
return false;
}
fout.write(reinterpret_cast<const char*>(&beta), sizeof(beta));
if (fout.bad()) {
std::cerr << "Error writing beta '" << path << "'" << std::endl;
return false;
}
fst::FstWriteOptions opt;
opt.align = true;
opt.source = path;
dictionary->Write(fout, opt);
return dictionary->Write(fout, opt);
}
bool Scorer::is_scoring_boundary(PathTrie* prefix, size_t new_label)

View File

@ -77,7 +77,7 @@ public:
void set_alphabet(const Alphabet& alphabet);
// save dictionary in file
void save_dictionary(const std::string &path, bool append_instead_of_overwrite=false);
bool save_dictionary(const std::string &path, bool append_instead_of_overwrite=false);
// return weather this step represents a boundary where beam scoring should happen
bool is_scoring_boundary(PathTrie* prefix, size_t new_label);