Serialize alphabet alongside checkpoint
This commit is contained in:
parent
5afe3c6e59
commit
87f0a371b1
|
@ -69,6 +69,24 @@ Alphabet::init(const char *config_file)
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string
|
||||||
|
Alphabet::SerializeText()
|
||||||
|
{
|
||||||
|
std::stringstream out;
|
||||||
|
|
||||||
|
out << "# Each line in this file represents the Unicode codepoint (UTF-8 encoded)\n"
|
||||||
|
<< "# associated with a numeric label.\n"
|
||||||
|
<< "# A line that starts with # is a comment. You can escape it with \\# if you wish\n"
|
||||||
|
<< "# to use '#' as a label.\n";
|
||||||
|
|
||||||
|
for (int label = 0; label < size_; ++label) {
|
||||||
|
out << label_to_str_[label] << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
out << "# The last (non-comment) line needs to end with a newline.\n";
|
||||||
|
return out.str();
|
||||||
|
}
|
||||||
|
|
||||||
std::string
|
std::string
|
||||||
Alphabet::Serialize()
|
Alphabet::Serialize()
|
||||||
{
|
{
|
||||||
|
|
|
@ -22,6 +22,9 @@ public:
|
||||||
// Serialize alphabet into a binary buffer.
|
// Serialize alphabet into a binary buffer.
|
||||||
std::string Serialize();
|
std::string Serialize();
|
||||||
|
|
||||||
|
// Serialize alphabet into a text representation (ie. config file read by `init`)
|
||||||
|
std::string SerializeText();
|
||||||
|
|
||||||
// Deserialize alphabet from a binary buffer.
|
// Deserialize alphabet from a binary buffer.
|
||||||
int Deserialize(const char* buffer, const int buffer_size);
|
int Deserialize(const char* buffer, const int buffer_size);
|
||||||
|
|
||||||
|
|
|
@ -429,6 +429,11 @@ def train():
|
||||||
with open_remote(flags_file, "w") as fout:
|
with open_remote(flags_file, "w") as fout:
|
||||||
json.dump(Config.serialize(), fout, indent=2)
|
json.dump(Config.serialize(), fout, indent=2)
|
||||||
|
|
||||||
|
# Serialize alphabet alongside checkpoint
|
||||||
|
preserved_alphabet_file = os.path.join(Config.save_checkpoint_dir, "alphabet.txt")
|
||||||
|
with open_remote(preserved_alphabet_file, "wb") as fout:
|
||||||
|
fout.write(Config.alphabet.SerializeText())
|
||||||
|
|
||||||
with tfv1.Session(config=Config.session_config) as session:
|
with tfv1.Session(config=Config.session_config) as session:
|
||||||
log_debug("Session opened.")
|
log_debug("Session opened.")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue