Fix #3053: Check output stream when producing scorer
This commit is contained in:
parent
c839ab5355
commit
6c2cbbd725
@ -61,8 +61,12 @@ def create_bundle(
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
scorer.fill_dictionary(list(words))
|
scorer.fill_dictionary(list(words))
|
||||||
shutil.copy(lm_path, package_path)
|
shutil.copy(lm_path, package_path)
|
||||||
scorer.save_dictionary(package_path, True) # append, not overwrite
|
# append, not overwrite
|
||||||
|
if scorer.save_dictionary(package_path, True):
|
||||||
print("Package created in {}".format(package_path))
|
print("Package created in {}".format(package_path))
|
||||||
|
else:
|
||||||
|
print("Error when creating {}".format(package_path))
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
class Tristate(object):
|
class Tristate(object):
|
||||||
|
@ -47,7 +47,7 @@ class Scorer(swigwrapper.Scorer):
|
|||||||
return super(Scorer, self).load_lm(lm_path.encode('utf-8'))
|
return super(Scorer, self).load_lm(lm_path.encode('utf-8'))
|
||||||
|
|
||||||
def save_dictionary(self, save_path, *args, **kwargs):
|
def save_dictionary(self, save_path, *args, **kwargs):
|
||||||
super(Scorer, self).save_dictionary(save_path.encode('utf-8'), *args, **kwargs)
|
return super(Scorer, self).save_dictionary(save_path.encode('utf-8'), *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def ctc_beam_search_decoder(probs_seq,
|
def ctc_beam_search_decoder(probs_seq,
|
||||||
|
@ -146,7 +146,7 @@ int Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
|
|||||||
return DS_ERR_OK;
|
return DS_ERR_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)
|
bool Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)
|
||||||
{
|
{
|
||||||
std::ios::openmode om;
|
std::ios::openmode om;
|
||||||
if (append_instead_of_overwrite) {
|
if (append_instead_of_overwrite) {
|
||||||
@ -155,15 +155,39 @@ void Scorer::save_dictionary(const std::string& path, bool append_instead_of_ove
|
|||||||
om = std::ios::out|std::ios::binary;
|
om = std::ios::out|std::ios::binary;
|
||||||
}
|
}
|
||||||
std::fstream fout(path, om);
|
std::fstream fout(path, om);
|
||||||
|
if (!fout ||fout.bad()) {
|
||||||
|
std::cerr << "Error opening '" << path << "'" << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
fout.write(reinterpret_cast<const char*>(&MAGIC), sizeof(MAGIC));
|
fout.write(reinterpret_cast<const char*>(&MAGIC), sizeof(MAGIC));
|
||||||
|
if (fout.bad()) {
|
||||||
|
std::cerr << "Error writing MAGIC '" << path << "'" << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
fout.write(reinterpret_cast<const char*>(&FILE_VERSION), sizeof(FILE_VERSION));
|
fout.write(reinterpret_cast<const char*>(&FILE_VERSION), sizeof(FILE_VERSION));
|
||||||
|
if (fout.bad()) {
|
||||||
|
std::cerr << "Error writing FILE_VERSION '" << path << "'" << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
fout.write(reinterpret_cast<const char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
fout.write(reinterpret_cast<const char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
||||||
|
if (fout.bad()) {
|
||||||
|
std::cerr << "Error writing is_utf8_mode '" << path << "'" << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
fout.write(reinterpret_cast<const char*>(&alpha), sizeof(alpha));
|
fout.write(reinterpret_cast<const char*>(&alpha), sizeof(alpha));
|
||||||
|
if (fout.bad()) {
|
||||||
|
std::cerr << "Error writing alpha '" << path << "'" << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
fout.write(reinterpret_cast<const char*>(&beta), sizeof(beta));
|
fout.write(reinterpret_cast<const char*>(&beta), sizeof(beta));
|
||||||
|
if (fout.bad()) {
|
||||||
|
std::cerr << "Error writing beta '" << path << "'" << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
fst::FstWriteOptions opt;
|
fst::FstWriteOptions opt;
|
||||||
opt.align = true;
|
opt.align = true;
|
||||||
opt.source = path;
|
opt.source = path;
|
||||||
dictionary->Write(fout, opt);
|
return dictionary->Write(fout, opt);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Scorer::is_scoring_boundary(PathTrie* prefix, size_t new_label)
|
bool Scorer::is_scoring_boundary(PathTrie* prefix, size_t new_label)
|
||||||
|
@ -77,7 +77,7 @@ public:
|
|||||||
void set_alphabet(const Alphabet& alphabet);
|
void set_alphabet(const Alphabet& alphabet);
|
||||||
|
|
||||||
// save dictionary in file
|
// save dictionary in file
|
||||||
void save_dictionary(const std::string &path, bool append_instead_of_overwrite=false);
|
bool save_dictionary(const std::string &path, bool append_instead_of_overwrite=false);
|
||||||
|
|
||||||
// return weather this step represents a boundary where beam scoring should happen
|
// return weather this step represents a boundary where beam scoring should happen
|
||||||
bool is_scoring_boundary(PathTrie* prefix, size_t new_label);
|
bool is_scoring_boundary(PathTrie* prefix, size_t new_label);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user