Fix #3053: Check output stream when producing scorer
This commit is contained in:
parent
c839ab5355
commit
6c2cbbd725
@ -61,8 +61,12 @@ def create_bundle(
|
||||
sys.exit(1)
|
||||
scorer.fill_dictionary(list(words))
|
||||
shutil.copy(lm_path, package_path)
|
||||
scorer.save_dictionary(package_path, True) # append, not overwrite
|
||||
print("Package created in {}".format(package_path))
|
||||
# append, not overwrite
|
||||
if scorer.save_dictionary(package_path, True):
|
||||
print("Package created in {}".format(package_path))
|
||||
else:
|
||||
print("Error when creating {}".format(package_path))
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class Tristate(object):
|
||||
|
@ -47,7 +47,7 @@ class Scorer(swigwrapper.Scorer):
|
||||
return super(Scorer, self).load_lm(lm_path.encode('utf-8'))
|
||||
|
||||
def save_dictionary(self, save_path, *args, **kwargs):
|
||||
super(Scorer, self).save_dictionary(save_path.encode('utf-8'), *args, **kwargs)
|
||||
return super(Scorer, self).save_dictionary(save_path.encode('utf-8'), *args, **kwargs)
|
||||
|
||||
|
||||
def ctc_beam_search_decoder(probs_seq,
|
||||
|
@ -146,7 +146,7 @@ int Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
|
||||
return DS_ERR_OK;
|
||||
}
|
||||
|
||||
void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)
|
||||
bool Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)
|
||||
{
|
||||
std::ios::openmode om;
|
||||
if (append_instead_of_overwrite) {
|
||||
@ -155,15 +155,39 @@ void Scorer::save_dictionary(const std::string& path, bool append_instead_of_ove
|
||||
om = std::ios::out|std::ios::binary;
|
||||
}
|
||||
std::fstream fout(path, om);
|
||||
if (!fout ||fout.bad()) {
|
||||
std::cerr << "Error opening '" << path << "'" << std::endl;
|
||||
return false;
|
||||
}
|
||||
fout.write(reinterpret_cast<const char*>(&MAGIC), sizeof(MAGIC));
|
||||
if (fout.bad()) {
|
||||
std::cerr << "Error writing MAGIC '" << path << "'" << std::endl;
|
||||
return false;
|
||||
}
|
||||
fout.write(reinterpret_cast<const char*>(&FILE_VERSION), sizeof(FILE_VERSION));
|
||||
if (fout.bad()) {
|
||||
std::cerr << "Error writing FILE_VERSION '" << path << "'" << std::endl;
|
||||
return false;
|
||||
}
|
||||
fout.write(reinterpret_cast<const char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
||||
if (fout.bad()) {
|
||||
std::cerr << "Error writing is_utf8_mode '" << path << "'" << std::endl;
|
||||
return false;
|
||||
}
|
||||
fout.write(reinterpret_cast<const char*>(&alpha), sizeof(alpha));
|
||||
if (fout.bad()) {
|
||||
std::cerr << "Error writing alpha '" << path << "'" << std::endl;
|
||||
return false;
|
||||
}
|
||||
fout.write(reinterpret_cast<const char*>(&beta), sizeof(beta));
|
||||
if (fout.bad()) {
|
||||
std::cerr << "Error writing beta '" << path << "'" << std::endl;
|
||||
return false;
|
||||
}
|
||||
fst::FstWriteOptions opt;
|
||||
opt.align = true;
|
||||
opt.source = path;
|
||||
dictionary->Write(fout, opt);
|
||||
return dictionary->Write(fout, opt);
|
||||
}
|
||||
|
||||
bool Scorer::is_scoring_boundary(PathTrie* prefix, size_t new_label)
|
||||
|
@ -77,7 +77,7 @@ public:
|
||||
void set_alphabet(const Alphabet& alphabet);
|
||||
|
||||
// save dictionary in file
|
||||
void save_dictionary(const std::string &path, bool append_instead_of_overwrite=false);
|
||||
bool save_dictionary(const std::string &path, bool append_instead_of_overwrite=false);
|
||||
|
||||
// return weather this step represents a boundary where beam scoring should happen
|
||||
bool is_scoring_boundary(PathTrie* prefix, size_t new_label);
|
||||
|
Loading…
x
Reference in New Issue
Block a user