Serialize alphabet alongside checkpoint
This commit is contained in:
parent
5afe3c6e59
commit
87f0a371b1
|
@ -69,6 +69,24 @@ Alphabet::init(const char *config_file)
|
|||
return 0;
|
||||
}
|
||||
|
||||
std::string
|
||||
Alphabet::SerializeText()
|
||||
{
|
||||
std::stringstream out;
|
||||
|
||||
out << "# Each line in this file represents the Unicode codepoint (UTF-8 encoded)\n"
|
||||
<< "# associated with a numeric label.\n"
|
||||
<< "# A line that starts with # is a comment. You can escape it with \\# if you wish\n"
|
||||
<< "# to use '#' as a label.\n";
|
||||
|
||||
for (int label = 0; label < size_; ++label) {
|
||||
out << label_to_str_[label] << "\n";
|
||||
}
|
||||
|
||||
out << "# The last (non-comment) line needs to end with a newline.\n";
|
||||
return out.str();
|
||||
}
|
||||
|
||||
std::string
|
||||
Alphabet::Serialize()
|
||||
{
|
||||
|
|
|
@ -22,6 +22,9 @@ public:
|
|||
// Serialize alphabet into a binary buffer.
|
||||
std::string Serialize();
|
||||
|
||||
// Serialize alphabet into a text representation (ie. config file read by `init`)
|
||||
std::string SerializeText();
|
||||
|
||||
// Deserialize alphabet from a binary buffer.
|
||||
int Deserialize(const char* buffer, const int buffer_size);
|
||||
|
||||
|
|
|
@ -429,6 +429,11 @@ def train():
|
|||
with open_remote(flags_file, "w") as fout:
|
||||
json.dump(Config.serialize(), fout, indent=2)
|
||||
|
||||
# Serialize alphabet alongside checkpoint
|
||||
preserved_alphabet_file = os.path.join(Config.save_checkpoint_dir, "alphabet.txt")
|
||||
with open_remote(preserved_alphabet_file, "wb") as fout:
|
||||
fout.write(Config.alphabet.SerializeText())
|
||||
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
log_debug("Session opened.")
|
||||
|
||||
|
|
Loading…
Reference in New Issue