From 6c2cbbd725928b4fba6e90d9bc7381bf3cc28911 Mon Sep 17 00:00:00 2001 From: Alexandre Lissy Date: Tue, 16 Jun 2020 12:40:37 +0200 Subject: [PATCH] Fix #3053: Check output stream when producing scorer --- data/lm/generate_package.py | 8 ++++++-- native_client/ctcdecode/__init__.py | 2 +- native_client/ctcdecode/scorer.cpp | 28 ++++++++++++++++++++++++++-- native_client/ctcdecode/scorer.h | 2 +- 4 files changed, 34 insertions(+), 6 deletions(-) diff --git a/data/lm/generate_package.py b/data/lm/generate_package.py index 27b00742..30a33fcc 100644 --- a/data/lm/generate_package.py +++ b/data/lm/generate_package.py @@ -61,8 +61,12 @@ def create_bundle( sys.exit(1) scorer.fill_dictionary(list(words)) shutil.copy(lm_path, package_path) - scorer.save_dictionary(package_path, True) # append, not overwrite - print("Package created in {}".format(package_path)) + # append, not overwrite + if scorer.save_dictionary(package_path, True): + print("Package created in {}".format(package_path)) + else: + print("Error when creating {}".format(package_path)) + sys.exit(1) class Tristate(object): diff --git a/native_client/ctcdecode/__init__.py b/native_client/ctcdecode/__init__.py index c9b917b3..7e3766be 100644 --- a/native_client/ctcdecode/__init__.py +++ b/native_client/ctcdecode/__init__.py @@ -47,7 +47,7 @@ class Scorer(swigwrapper.Scorer): return super(Scorer, self).load_lm(lm_path.encode('utf-8')) def save_dictionary(self, save_path, *args, **kwargs): - super(Scorer, self).save_dictionary(save_path.encode('utf-8'), *args, **kwargs) + return super(Scorer, self).save_dictionary(save_path.encode('utf-8'), *args, **kwargs) def ctc_beam_search_decoder(probs_seq, diff --git a/native_client/ctcdecode/scorer.cpp b/native_client/ctcdecode/scorer.cpp index 1834c21c..ebf55227 100644 --- a/native_client/ctcdecode/scorer.cpp +++ b/native_client/ctcdecode/scorer.cpp @@ -146,7 +146,7 @@ int Scorer::load_trie(std::ifstream& fin, const std::string& file_path) return DS_ERR_OK; } -void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite) +bool Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite) { std::ios::openmode om; if (append_instead_of_overwrite) { @@ -155,15 +155,39 @@ void Scorer::save_dictionary(const std::string& path, bool append_instead_of_ove om = std::ios::out|std::ios::binary; } std::fstream fout(path, om); + if (!fout ||fout.bad()) { + std::cerr << "Error opening '" << path << "'" << std::endl; + return false; + } fout.write(reinterpret_cast(&MAGIC), sizeof(MAGIC)); + if (fout.bad()) { + std::cerr << "Error writing MAGIC '" << path << "'" << std::endl; + return false; + } fout.write(reinterpret_cast(&FILE_VERSION), sizeof(FILE_VERSION)); + if (fout.bad()) { + std::cerr << "Error writing FILE_VERSION '" << path << "'" << std::endl; + return false; + } fout.write(reinterpret_cast(&is_utf8_mode_), sizeof(is_utf8_mode_)); + if (fout.bad()) { + std::cerr << "Error writing is_utf8_mode '" << path << "'" << std::endl; + return false; + } fout.write(reinterpret_cast(&alpha), sizeof(alpha)); + if (fout.bad()) { + std::cerr << "Error writing alpha '" << path << "'" << std::endl; + return false; + } fout.write(reinterpret_cast(&beta), sizeof(beta)); + if (fout.bad()) { + std::cerr << "Error writing beta '" << path << "'" << std::endl; + return false; + } fst::FstWriteOptions opt; opt.align = true; opt.source = path; - dictionary->Write(fout, opt); + return dictionary->Write(fout, opt); } bool Scorer::is_scoring_boundary(PathTrie* prefix, size_t new_label) diff --git a/native_client/ctcdecode/scorer.h b/native_client/ctcdecode/scorer.h index 55f337ed..d2a1c8b3 100644 --- a/native_client/ctcdecode/scorer.h +++ b/native_client/ctcdecode/scorer.h @@ -77,7 +77,7 @@ public: void set_alphabet(const Alphabet& alphabet); // save dictionary in file - void save_dictionary(const std::string &path, bool append_instead_of_overwrite=false); + bool save_dictionary(const std::string &path, bool append_instead_of_overwrite=false); // return weather this step represents a boundary where beam scoring should happen bool is_scoring_boundary(PathTrie* prefix, size_t new_label);