Serialize alphabet alongside checkpoint

This commit is contained in:
Reuben Morais 2021-08-23 16:02:13 +02:00
parent 5afe3c6e59
commit 87f0a371b1
3 changed files with 26 additions and 0 deletions

View File

@ -69,6 +69,24 @@ Alphabet::init(const char *config_file)
return 0;
}
// Serialize the alphabet into the text form of an alphabet config file:
// a block of explanatory '#' comment lines, then one label per line in
// increasing label order (0 .. size_ - 1), then a closing comment line.
//
// In this format a line whose first character is '#' is a comment, so a label
// that begins with '#' is written with a leading backslash ("\#"), which is
// the escape the emitted header describes. Written verbatim, such a label
// would be indistinguishable from a comment in the output and would be lost
// when the text is parsed again.
// NOTE(review): assumes the reader (`init`) un-escapes a leading "\#" — confirm.
//
// Returns the complete file contents as a std::string.
std::string
Alphabet::SerializeText()
{
  std::stringstream out;

  out << "# Each line in this file represents the Unicode codepoint (UTF-8 encoded)\n"
      << "# associated with a numeric label.\n"
      << "# A line that starts with # is a comment. You can escape it with \\# if you wish\n"
      << "# to use '#' as a label.\n";

  for (int label = 0; label < size_; ++label) {
    const std::string& text = label_to_str_[label];
    // Escape labels that would otherwise be read back as comment lines.
    if (!text.empty() && text.front() == '#') {
      out << '\\';
    }
    out << text << "\n";
  }

  out << "# The last (non-comment) line needs to end with a newline.\n";
  return out.str();
}
std::string
Alphabet::Serialize()
{

View File

@ -22,6 +22,9 @@ public:
// Serialize alphabet into a binary buffer.
// The buffer is returned by value as a std::string.
// NOTE(review): presumably this is the format `Deserialize` accepts — confirm.
std::string Serialize();
// Serialize alphabet into a text representation (i.e. the config file format
// read by `init`): explanatory '#' comment lines, followed by one label per
// line in increasing label order, followed by a closing comment line.
// The whole text is returned by value as a std::string.
std::string SerializeText();
// Deserialize alphabet from a binary buffer.
// `buffer` points to the serialized data and `buffer_size` gives its length
// (presumably in bytes, matching the output of `Serialize` — confirm).
// NOTE(review): the meaning of the int return value is not visible here;
// `init` returns 0 on success, so presumably this does too — confirm.
int Deserialize(const char* buffer, const int buffer_size);

View File

@ -429,6 +429,11 @@ def train():
with open_remote(flags_file, "w") as fout:
json.dump(Config.serialize(), fout, indent=2)
# Serialize alphabet alongside checkpoint
preserved_alphabet_file = os.path.join(Config.save_checkpoint_dir, "alphabet.txt")
with open_remote(preserved_alphabet_file, "wb") as fout:
fout.write(Config.alphabet.SerializeText())
with tfv1.Session(config=Config.session_config) as session:
log_debug("Session opened.")